aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/mi/models/dao.go2
-rw-r--r--cmd/mi/models/switch.go6
-rw-r--r--cmd/mi/services/switchtracker/switchtracker.go3
-rw-r--r--cmd/mi/services/switchtracker/switchtracker_test.go124
-rw-r--r--cmd/mi/services/switchtracker/testdata/members.json10
5 files changed, 141 insertions, 4 deletions
diff --git a/cmd/mi/models/dao.go b/cmd/mi/models/dao.go
index 3241d5e..b00ed96 100644
--- a/cmd/mi/models/dao.go
+++ b/cmd/mi/models/dao.go
@@ -129,7 +129,7 @@ func (d *DAO) GetSwitch(ctx context.Context, id string) (*Switch, error) {
var sw Switch
if err := d.db.WithContext(ctx).
Joins("Member").
- Where("id = ?", id).
+ Where("switches.id = ?", id).
First(&sw).Error; err != nil {
return nil, err
}
diff --git a/cmd/mi/models/switch.go b/cmd/mi/models/switch.go
index ec85a1c..5a16dd6 100644
--- a/cmd/mi/models/switch.go
+++ b/cmd/mi/models/switch.go
@@ -17,7 +17,11 @@ type Switch struct {
}
// AsProto converts a Switch to its protobuf representation.
-func (s Switch) AsProto() *pb.Switch {
+func (s *Switch) AsProto() *pb.Switch {
+ if s == nil {
+ return nil
+ }
+
var endedAt string
if s.EndedAt != nil {
diff --git a/cmd/mi/services/switchtracker/switchtracker.go b/cmd/mi/services/switchtracker/switchtracker.go
index a7c1631..e7b2c5c 100644
--- a/cmd/mi/services/switchtracker/switchtracker.go
+++ b/cmd/mi/services/switchtracker/switchtracker.go
@@ -13,7 +13,6 @@ import (
)
type SwitchTracker struct {
- db *gorm.DB
dao *models.DAO
}
@@ -65,7 +64,7 @@ func (s *SwitchTracker) Switch(ctx context.Context, req *pb.SwitchReq) (*pb.Swit
slog.Error("can't switch front", "req", req, "err", err)
switch {
case errors.Is(err, models.ErrCantSwitchToYourself):
- twirp.InvalidArgumentError("member_name", "cannot switch to yourself").
+ return nil, twirp.InvalidArgumentError("member_name", "cannot switch to yourself").
WithMeta("member_name", req.GetMemberName())
case errors.Is(err, gorm.ErrRecordNotFound):
return nil, twirp.NotFoundError("can't find current switch")
diff --git a/cmd/mi/services/switchtracker/switchtracker_test.go b/cmd/mi/services/switchtracker/switchtracker_test.go
new file mode 100644
index 0000000..9def86d
--- /dev/null
+++ b/cmd/mi/services/switchtracker/switchtracker_test.go
@@ -0,0 +1,124 @@
+package switchtracker_test
+
+import (
+ "context"
+ "crypto/rand"
+ _ "embed"
+ "encoding/json"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/oklog/ulid/v2"
+ "within.website/x/cmd/mi/models"
+ "within.website/x/cmd/mi/services/switchtracker"
+ pb "within.website/x/proto/mi"
+)
+
+var (
+ //go:embed testdata/members.json
+ membersJSON []byte
+)
+
+func TestSwitch(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ dir := t.TempDir()
+ dao, err := models.New(filepath.Join(dir, "test.db"))
+ if err != nil {
+ t.Fatalf("failed to create dao: %v", err)
+ }
+
+ st := switchtracker.New(dao)
+
+ // Import members and create root switch
+
+ var members []*models.Member
+ if err := json.Unmarshal(membersJSON, &members); err != nil {
+ t.Fatalf("failed to unmarshal members: %v", err)
+ }
+
+ for _, m := range members {
+ if err := dao.DB().Create(m).Error; err != nil {
+ t.Fatalf("failed to create member: %v", err)
+ }
+ }
+
+ if err := dao.DB().Create(&models.Switch{
+ ID: ulid.MustNew(ulid.Now(), rand.Reader).String(),
+ MemberID: members[0].ID,
+ }).Error; err != nil {
+ t.Fatalf("failed to create root switch: %v", err)
+ }
+
+ resp, err := st.Members(ctx, nil)
+ if err != nil {
+ t.Errorf("failed to get members: %v", err)
+ }
+
+ if len(resp.Members) != len(members) {
+ t.Errorf("expected %d members, got %d", len(members), len(resp.Members))
+ }
+
+ front, err := st.WhoIsFront(ctx, nil)
+ if err != nil {
+ t.Errorf("failed to get front: %v", err)
+ }
+
+ if front.Member.Name != members[0].Name {
+ t.Errorf("expected front to be %s, got %s", members[0].Name, front.Member.Name)
+ }
+
+ _, err = st.Switch(ctx, &pb.SwitchReq{MemberName: members[1].Name})
+ if err != nil {
+ t.Errorf("failed to switch front: %v", err)
+ }
+
+ t.Log("trying to switch to current front")
+ front, err = st.WhoIsFront(ctx, nil)
+ if err != nil {
+ t.Errorf("failed to get front: %v", err)
+ }
+
+ if front.Member.Name != members[1].Name {
+ t.Errorf("expected front to be %s, got %s", members[1].Name, front.Member.Name)
+ }
+
+ front, err = st.WhoIsFront(ctx, nil)
+ if err != nil {
+ t.Errorf("failed to get front: %v", err)
+ }
+
+ _, err = st.Switch(ctx, &pb.SwitchReq{MemberName: front.Member.Name})
+ if err == nil {
+ t.Errorf("expected error, got nil")
+ }
+
+ if !strings.HasSuffix(err.Error(), "cannot switch to yourself") {
+ t.Errorf("expected error to be 'cannot switch to yourself', got %v", err)
+ }
+
+ switches, err := st.ListSwitches(ctx, &pb.ListSwitchesReq{
+ Count: 10,
+ })
+ if err != nil {
+ t.Errorf("failed to list switches: %v", err)
+ }
+
+ for _, s := range switches.Switches {
+ s := s
+ t.Run("get switch "+s.GetSwitch().GetId(), func(t *testing.T) {
+ fc, err := st.GetSwitch(ctx, &pb.GetSwitchReq{
+ Id: s.GetSwitch().GetId(),
+ })
+ if err != nil {
+ t.Errorf("failed to get switch: %v", err)
+ }
+
+ if fc.GetSwitch().GetId() != s.GetSwitch().GetId() {
+ t.Errorf("expected switch ID to be %s, got %s", s.GetSwitch().GetId(), fc.GetSwitch().GetId())
+ }
+ })
+ }
+}
diff --git a/cmd/mi/services/switchtracker/testdata/members.json b/cmd/mi/services/switchtracker/testdata/members.json
new file mode 100644
index 0000000..0b48bc3
--- /dev/null
+++ b/cmd/mi/services/switchtracker/testdata/members.json
@@ -0,0 +1,10 @@
+[
+ {
+ "ID": 1,
+ "Name": "Jessie"
+ },
+ {
+ "ID": 2,
+ "Name": "Sephie"
+ }
+] \ No newline at end of file