aboutsummaryrefslogtreecommitdiff
path: root/web/mastodon/websocket.go
blob: c867e6ae8fc976aeed3afbc0667d32369c674d01 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package mastodon

import (
	"context"
	"encoding/json"
	"log/slog"
	"net/url"
	"time"

	"nhooyr.io/websocket"
)

// WSSubscribeRequest is a websocket instruction to subscribe to (or
// unsubscribe from) a streaming feed.
//
// Each request is marshalled to JSON and written to the websocket as a text
// message right after the connection is established. Type selects the
// operation, Stream names the feed, and Hashtag is left out of the JSON
// entirely when it is empty.
type WSSubscribeRequest struct {
	Type    string `json:"type"` // should be "subscribe" or "unsubscribe"
	Stream  string `json:"stream"`
	Hashtag string `json:"hashtag,omitempty"`
}

// WSMessage is a websocket message. Whenever you get something from the
// streaming service, it will fit into this box.
//
// Every text frame read from the websocket is decoded into one WSMessage.
// Payload is delivered as a string that itself contains JSON; this package
// does not decode it further, so the consumer is expected to unmarshal it
// (presumably based on Event -- confirm against the streaming API docs).
type WSMessage struct {
	Stream  []string `json:"stream"`
	Event   string   `json:"event"`
	Payload string   `json:"payload"` // json string
}

// StreamMessages is a low-level message streaming facility.
//
// It derives the streaming endpoint (/api/v1/streaming) from the client's
// server URL, switching http/https to ws/wss and attaching the client's
// token as the access_token query parameter. A background goroutine then
// connects, sends every request in subreq, and forwards each decoded message
// to the returned channel. Whenever the connection ends, the goroutine waits
// one minute and reconnects, re-sending subreq.
//
// The goroutine stops as soon as ctx is cancelled, including while it is
// waiting to reconnect. The returned channel is buffered (10 messages) and is
// never closed, so consumers should watch ctx themselves rather than rely on
// a close signal.
//
// An error is returned only when the streaming URL cannot be built;
// connection errors are logged and retried.
func (c *Client) StreamMessages(ctx context.Context, subreq ...WSSubscribeRequest) (chan WSMessage, error) {
	// Delay between the end of one connection and the next attempt.
	const retryDelay = time.Minute

	result := make(chan WSMessage, 10)

	u, err := c.server.Parse("/api/v1/streaming")
	if err != nil {
		return nil, err
	}

	switch u.Scheme {
	case "http":
		u.Scheme = "ws"
	case "https":
		u.Scheme = "wss"
	}

	q := u.Query()
	q.Set("access_token", c.token)
	u.RawQuery = q.Encode()

	go func(ctx context.Context) {
		for ctx.Err() == nil {
			err := doWebsocket(ctx, u, result, subreq)

			// A cancelled context is a normal shutdown, not a failure:
			// do not log a misleading "retrying" message, just stop.
			if ctx.Err() != nil {
				return
			}
			if err != nil {
				slog.Error("websocket error, retrying", "err", err)
			}

			// Wait before reconnecting, but wake up immediately on
			// cancellation so the goroutine does not outlive ctx.
			wait := time.NewTimer(retryDelay)
			select {
			case <-ctx.Done():
				wait.Stop()
				return
			case <-wait.C:
			}
		}
	}(ctx)

	return result, nil
}

// doWebsocket runs a single websocket session against u.
//
// It dials the URL, writes each entry of subreq as a JSON text message, and
// then reads messages until something fails. Text messages are decoded into
// WSMessage values and sent on result; non-text messages are logged at debug
// level and skipped.
//
// The function only returns with a non-nil error: the dial/write/read/decode
// error that ended the session, or ctx.Err() once ctx is cancelled. The
// connection is closed with a normal-closure status on return.
func doWebsocket(ctx context.Context, u *url.URL, result chan WSMessage, subreq []WSSubscribeRequest) error {
	conn, _, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{})
	if err != nil {
		return err
	}
	defer conn.Close(websocket.StatusNormalClosure, "doWebsocket function returned")

	for _, sub := range subreq {
		data, err := json.Marshal(sub)
		if err != nil {
			return err
		}
		if err := conn.Write(ctx, websocket.MessageText, data); err != nil {
			return err
		}
	}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()

		default:
		}

		msgType, data, err := conn.Read(ctx)
		if err != nil {
			return err
		}

		if msgType != websocket.MessageText {
			slog.Debug("got non-text message from mastodon", "data", data)
			continue
		}

		var msg WSMessage
		if err := json.Unmarshal(data, &msg); err != nil {
			return err
		}

		// Hand the message to the consumer, but give up if ctx is cancelled
		// while the channel is full; otherwise a consumer that stopped
		// reading would pin this goroutine and the connection forever.
		select {
		case result <- msg:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}