diff options
| -rw-r--r-- | cmd/mi/models/dao.go | 2 | ||||
| -rw-r--r-- | cmd/mi/models/switch.go | 6 | ||||
| -rw-r--r-- | cmd/mi/services/switchtracker/switchtracker.go | 3 | ||||
| -rw-r--r-- | cmd/mi/services/switchtracker/switchtracker_test.go | 124 | ||||
| -rw-r--r-- | cmd/mi/services/switchtracker/testdata/members.json | 10 |
5 files changed, 141 insertions, 4 deletions
diff --git a/cmd/mi/models/dao.go b/cmd/mi/models/dao.go index 3241d5e..b00ed96 100644 --- a/cmd/mi/models/dao.go +++ b/cmd/mi/models/dao.go @@ -129,7 +129,7 @@ func (d *DAO) GetSwitch(ctx context.Context, id string) (*Switch, error) { var sw Switch if err := d.db.WithContext(ctx). Joins("Member"). - Where("id = ?", id). + Where("switches.id = ?", id). First(&sw).Error; err != nil { return nil, err } diff --git a/cmd/mi/models/switch.go b/cmd/mi/models/switch.go index ec85a1c..5a16dd6 100644 --- a/cmd/mi/models/switch.go +++ b/cmd/mi/models/switch.go @@ -17,7 +17,11 @@ type Switch struct { } // AsProto converts a Switch to its protobuf representation. -func (s Switch) AsProto() *pb.Switch { +func (s *Switch) AsProto() *pb.Switch { + if s == nil { + return nil + } + var endedAt string if s.EndedAt != nil { diff --git a/cmd/mi/services/switchtracker/switchtracker.go b/cmd/mi/services/switchtracker/switchtracker.go index a7c1631..e7b2c5c 100644 --- a/cmd/mi/services/switchtracker/switchtracker.go +++ b/cmd/mi/services/switchtracker/switchtracker.go @@ -13,7 +13,6 @@ import ( ) type SwitchTracker struct { - db *gorm.DB dao *models.DAO } @@ -65,7 +64,7 @@ func (s *SwitchTracker) Switch(ctx context.Context, req *pb.SwitchReq) (*pb.Swit slog.Error("can't switch front", "req", req, "err", err) switch { case errors.Is(err, models.ErrCantSwitchToYourself): - twirp.InvalidArgumentError("member_name", "cannot switch to yourself"). + return nil, twirp.InvalidArgumentError("member_name", "cannot switch to yourself"). WithMeta("member_name", req.GetMemberName()) case errors.Is(err, gorm.ErrRecordNotFound): return nil, twirp.NotFoundError("can't find current switch") diff --git a/cmd/mi/services/switchtracker/switchtracker_test.go b/cmd/mi/services/switchtracker/switchtracker_test.go new file mode 100644 index 0000000..9def86d --- /dev/null +++ b/cmd/mi/services/switchtracker/switchtracker_test.go @@ -0,0 +1,124 @@ +package switchtracker_test + +import ( + "context" + "crypto/rand" + _ "embed" + "encoding/json" + "path/filepath" + "strings" + "testing" + + "github.com/oklog/ulid/v2" + "within.website/x/cmd/mi/models" + "within.website/x/cmd/mi/services/switchtracker" + pb "within.website/x/proto/mi" +) + +var ( + //go:embed testdata/members.json + membersJSON []byte +) + +func TestSwitch(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + dao, err := models.New(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("failed to create dao: %v", err) + } + + st := switchtracker.New(dao) + + // Import members and create root switch + + var members []*models.Member + if err := json.Unmarshal(membersJSON, &members); err != nil { + t.Fatalf("failed to unmarshal members: %v", err) + } + + for _, m := range members { + if err := dao.DB().Create(m).Error; err != nil { + t.Fatalf("failed to create member: %v", err) + } + } + + if err := dao.DB().Create(&models.Switch{ + ID: ulid.MustNew(ulid.Now(), rand.Reader).String(), + MemberID: members[0].ID, + }).Error; err != nil { + t.Fatalf("failed to create root switch: %v", err) + } + + resp, err := st.Members(ctx, nil) + if err != nil { + t.Errorf("failed to get members: %v", err) + } + + if len(resp.Members) != len(members) { + t.Errorf("expected %d members, got %d", len(members), len(resp.Members)) + } + + front, err := st.WhoIsFront(ctx, nil) + if err != nil { + t.Errorf("failed to get front: %v", err) + } + + if front.Member.Name != members[0].Name { + t.Errorf("expected front to be %s, got %s", members[0].Name, front.Member.Name) + } + + _, err = st.Switch(ctx, &pb.SwitchReq{MemberName: members[1].Name}) + if err != nil { + t.Errorf("failed to switch front: %v", err) + } + + t.Log("trying to switch to current front") + front, err = st.WhoIsFront(ctx, nil) + if err != nil { + t.Errorf("failed to get front: %v", err) + } + + if front.Member.Name != members[1].Name { + t.Errorf("expected front to be %s, got %s", members[1].Name, front.Member.Name) + } + + front, err = st.WhoIsFront(ctx, nil) + if err != nil { + t.Errorf("failed to get front: %v", err) + } + + _, err = st.Switch(ctx, &pb.SwitchReq{MemberName: front.Member.Name}) + if err == nil { + t.Errorf("expected error, got nil") + } + + if !strings.HasSuffix(err.Error(), "cannot switch to yourself") { + t.Errorf("expected error to be 'cannot switch to yourself', got %v", err) + } + + switches, err := st.ListSwitches(ctx, &pb.ListSwitchesReq{ + Count: 10, + }) + if err != nil { + t.Errorf("failed to list switches: %v", err) + } + + for _, s := range switches.Switches { + s := s + t.Run("get switch "+s.GetSwitch().GetId(), func(t *testing.T) { + fc, err := st.GetSwitch(ctx, &pb.GetSwitchReq{ + Id: s.GetSwitch().GetId(), + }) + if err != nil { + t.Errorf("failed to get switch: %v", err) + } + + if fc.GetSwitch().GetId() != s.GetSwitch().GetId() { + t.Errorf("expected switch ID to be %s, got %s", s.GetSwitch().GetId(), fc.GetSwitch().GetId()) + } + }) + } +} diff --git a/cmd/mi/services/switchtracker/testdata/members.json b/cmd/mi/services/switchtracker/testdata/members.json new file mode 100644 index 0000000..0b48bc3 --- /dev/null +++ b/cmd/mi/services/switchtracker/testdata/members.json @@ -0,0 +1,10 @@ +[ + { + "ID": 1, + "Name": "Jessie" + }, + { + "ID": 2, + "Name": "Sephie" + } +]
\ No newline at end of file |
