diff options
| author | Xe Iaso <me@xeiaso.net> | 2023-06-22 07:45:51 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2023-06-22 07:45:51 -0400 |
| commit | f8251f93d4a500e6a5b82bdb045cd5e0c51a7140 (patch) | |
| tree | e6e100cd3c47624bbdb1ef913aedbf65ae69dc71 /cmd | |
| parent | 19fd40f5d6294d9154181b2340a91dbb9524e582 (diff) | |
| download | x-f8251f93d4a500e6a5b82bdb045cd5e0c51a7140.tar.xz x-f8251f93d4a500e6a5b82bdb045cd5e0c51a7140.zip | |
cmd/marabot: try some things to make the SQLite inserts fail less
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/marabot/.gitignore | 1 | ||||
| -rw-r--r-- | cmd/marabot/discord.go | 212 | ||||
| -rw-r--r-- | cmd/marabot/main.go | 81 | ||||
| -rw-r--r-- | cmd/marabot/revolt.go | 2 | ||||
| -rw-r--r-- | cmd/marabot/schema.sql | 2 |
5 files changed, 256 insertions, 42 deletions
diff --git a/cmd/marabot/.gitignore b/cmd/marabot/.gitignore index 9cd792a..657f4e1 100644 --- a/cmd/marabot/.gitignore +++ b/cmd/marabot/.gitignore @@ -2,3 +2,4 @@ *.db-shm *.db-wal .marabot.db-litestream +*.csv diff --git a/cmd/marabot/discord.go b/cmd/marabot/discord.go index 3d8e4a1..30bee83 100644 --- a/cmd/marabot/discord.go +++ b/cmd/marabot/discord.go @@ -59,7 +59,10 @@ func (mr *MaraRevolt) importDiscordData(ctx context.Context, db *sql.DB, dg *dis ln.Error(ctx, err, ln.Action("inserting emoji")) continue } - mr.attachmentPreprocess.Add([3]string{eURL, "emoji", ""}, len(eURL)) + + if err := mr.archiveAttachment(ctx, tx, eURL, "emoji", ""); err != nil { + return err + } } rows, err := tx.QueryContext(ctx, "SELECT url, message_id FROM discord_attachments WHERE url NOT IN ( SELECT url FROM s3_uploads )") @@ -71,7 +74,9 @@ func (mr *MaraRevolt) importDiscordData(ctx context.Context, db *sql.DB, dg *dis continue } - mr.attachmentPreprocess.Add([3]string{url, "attachments", messageID}, len(url)) + if err := mr.archiveAttachment(ctx, tx, url, "attachments", messageID); err != nil { + return err + } } } @@ -83,6 +88,9 @@ func (mr *MaraRevolt) importDiscordData(ctx context.Context, db *sql.DB, dg *dis } func (mr *MaraRevolt) DiscordMessageDelete(s *discordgo.Session, m *discordgo.MessageDelete) { + mr.lock.Lock() + defer mr.lock.Unlock() + ctx := opname.With(context.Background(), "marabot.discord-message-delete") tx, err := mr.db.Begin() @@ -128,28 +136,159 @@ func (mr *MaraRevolt) DiscordMessageDelete(s *discordgo.Session, m *discordgo.Me } func (mr *MaraRevolt) DiscordMessageEdit(s *discordgo.Session, m *discordgo.MessageUpdate) { + mr.lock.Lock() + defer mr.lock.Unlock() + if _, err := mr.db.Exec("UPDATE discord_messages SET content = ?, edited_at = ? WHERE id = ?", m.Content, time.Now().Format(time.RFC3339), m.ID); err != nil { ln.Error(context.Background(), err) } } func (mr *MaraRevolt) DiscordMessageCreate(s *discordgo.Session, m *discordgo.MessageCreate) { + mr.lock.Lock() + defer mr.lock.Unlock() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() ctx = opname.With(ctx, "marabot.discordMessageCreate") - if err := mr.discordMessageCreate(ctx, s, m); err != nil { - ln.Error(context.Background(), err) + tx, err := mr.db.Begin() + if err != nil { + ln.Error(ctx, err) + return + } + defer tx.Rollback() + + if err := mr.discordMessageCreate(ctx, tx, s, m.Message); err != nil { + ln.Error(context.Background(), err, ln.F{ + "channel_id": m.ChannelID, + "message_id": m.ID, + }) + s.MessageReactionAdd(m.ChannelID, m.ID, "🔥") + } + + if err := tx.Commit(); err != nil { + ln.Error(context.Background(), err, ln.F{ + "channel_id": m.ChannelID, + "message_id": m.ID, + }) } } -func (mr *MaraRevolt) discordMessageCreate(ctx context.Context, s *discordgo.Session, m *discordgo.MessageCreate) error { +func (mr *MaraRevolt) DiscordReactionAdd(s *discordgo.Session, mra *discordgo.MessageReactionAdd) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = opname.With(ctx, "marabot.discord-reaction-add") + + if mra.Emoji.Name == "💾" { + go mr.backfillDiscordChannel(s, mra.ChannelID, mra.MessageID) + return + } + + m, err := s.ChannelMessage(mra.ChannelID, mra.MessageID) + if err != nil { + ln.Error(ctx, err) + return + } + defer s.MessageReactionRemove(m.ChannelID, m.ID, "🔥", "@me") + tx, err := mr.db.Begin() if err != nil { - return err + ln.Error(ctx, err) + return } defer tx.Rollback() + if err := mr.discordMessageCreate(ctx, tx, s, m); err != nil { + ln.Error(context.Background(), err, ln.F{ + "channel_id": m.ChannelID, + "message_id": m.ID, + }) + s.MessageReactionAdd(m.ChannelID, m.ID, "ðŸ˜") + } + + if err := tx.Commit(); err != nil { + ln.Error(ctx, err, ln.F{ + "channel_id": m.ChannelID, + "message_id": m.ID, + }) + } +} + +func (mr *MaraRevolt) doesDiscordMessageExist(ctx context.Context, tx *sql.Tx, messageID string) (bool, error) { + var count int + if err := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM discord_messages WHERE id = ?", messageID).Scan(&count); err != nil { + return false, err + } + + if count > 0 { + return true, nil + } + + return false, nil +} + +func (mr *MaraRevolt) backfillDiscordChannel(s *discordgo.Session, channelID, messageID string) { + curr := messageID + ctx := opname.With(context.Background(), "marabot.backfillDiscordChannel") + + ln.Log(ctx, ln.Action("archiving channel from message"), ln.F{"channel_id": channelID, "message_id": messageID}) + + tx, err := mr.db.Begin() + if err != nil { + ln.Error(ctx, err) + return + } + defer tx.Rollback() + + t := time.NewTicker(30 * time.Second) + defer t.Stop() + + done := false + + for range t.C { + ln.Log(ctx, ln.Action("fetching batch of messages"), ln.F{"curr": curr}) + msgs, err := s.ChannelMessages(channelID, 100, "", curr, "") + if err != nil { + ln.Error(ctx, err) + s.ChannelMessageSend(channelID, fmt.Sprintf("error getting messages past %s: %v", curr, err)) + break + } + + for _, msg := range msgs { + found, err := mr.doesDiscordMessageExist(ctx, tx, msg.ID) + if err != nil { + ln.Error(ctx, err, ln.F{"message_id": msg.ID}) + continue + } + + if found { + ln.Log(ctx, ln.Action("stopping archival")) + done = true + } + + if err := mr.discordMessageCreate(ctx, tx, s, msg); err != nil { + ln.Error(ctx, err, ln.F{"message_id": msg.ID}) + continue + } + + curr = msg.ID + } + + if done { + break + } + } + + if err := tx.Commit(); err != nil { + ln.Error(ctx, err, ln.F{ + "channel_id": channelID, + "message_id": messageID, + }) + } +} + +func (mr *MaraRevolt) discordMessageCreate(ctx context.Context, tx *sql.Tx, s *discordgo.Session, m *discordgo.Message) error { if _, err := tx.ExecContext(ctx, `INSERT INTO discord_users (id, username, avatar_url, accent_color) VALUES (?, ?, ?, ?) ON CONFLICT(id) @@ -157,56 +296,61 @@ DO UPDATE SET username = EXCLUDED.username, avatar_url = EXCLUDED.avatar_url, ac return err } - mr.attachmentPreprocess.Add([3]string{m.Author.AvatarURL(""), "avatars", ""}, len(m.Author.Avatar)) + if err := mr.archiveAttachment(ctx, tx, m.Author.AvatarURL(""), "avatars", ""); err != nil { + return err + } - if _, err := tx.ExecContext(ctx, `INSERT INTO discord_messages (id, guild_id, channel_id, author_id, content, created_at, edited_at, webhook_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, m.ID, m.GuildID, m.ChannelID, m.Author.ID, m.Content, m.Timestamp.Format(time.RFC3339), m.EditedTimestamp, m.WebhookID); err != nil { + if _, err := tx.ExecContext(ctx, `INSERT INTO discord_messages (id, guild_id, channel_id, author_id, content, created_at, edited_at, webhook_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET content = EXCLUDED.content`, m.ID, m.GuildID, m.ChannelID, m.Author.ID, m.Content, m.Timestamp.Format(time.RFC3339), m.EditedTimestamp, m.WebhookID); err != nil { return err } if m.WebhookID != "" { - if _, err := tx.ExecContext(ctx, "INSERT INTO discord_webhook_message_info (id, name, avatar_url) VALUES (?, ?, ?)", m.ID, m.Author.Username, m.Author.AvatarURL("")); err != nil { + if _, err := tx.ExecContext(ctx, "INSERT INTO discord_webhook_message_info (id, name, avatar_url) VALUES (?, ?, ?) ON CONFLICT DO NOTHING", m.ID, m.Author.Username, m.Author.AvatarURL("")); err != nil { return err } } for _, att := range m.Attachments { - if _, err := tx.ExecContext(ctx, `INSERT INTO discord_attachments (id, message_id, url, proxy_url, filename, content_type, width, height, size) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, att.ID, m.ID, att.URL, att.ProxyURL, att.Filename, att.ContentType, att.Width, att.Height, att.Size); err != nil { + if _, err := tx.ExecContext(ctx, `INSERT INTO discord_attachments (id, message_id, url, proxy_url, filename, content_type, width, height, size) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT DO NOTHING`, att.ID, m.ID, att.URL, att.ProxyURL, att.Filename, att.ContentType, att.Width, att.Height, att.Size); err != nil { return err } - mr.attachmentPreprocess.Add([3]string{att.URL, "attachments", m.ID}, len(att.URL)) - } - - for _, emb := range m.Embeds { - if emb.Image == nil { - continue - } - if _, err := tx.ExecContext(ctx, `INSERT INTO discord_attachments (id, message_id, url, proxy_url, filename, content_type, width, height, size) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, uuid.NewString(), m.ID, emb.Image.URL, emb.Image.ProxyURL, filepath.Base(emb.Image.URL), "", emb.Image.Width, emb.Image.Height, 0); err != nil { + if err := mr.archiveAttachment(ctx, tx, att.URL, "attachments", m.ID); err != nil { return err } - mr.attachmentPreprocess.Add([3]string{emb.Image.URL, "attachments", m.ID}, len(emb.Image.URL)) - } + for _, emb := range m.Embeds { + if emb.Image == nil { + continue + } + if _, err := tx.ExecContext(ctx, `INSERT INTO discord_attachments (id, message_id, url, proxy_url, filename, content_type, width, height, size) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT DO NOTHING`, uuid.NewString(), m.ID, emb.Image.URL, emb.Image.ProxyURL, filepath.Base(emb.Image.URL), "", emb.Image.Width, emb.Image.Height, 0); err != nil { + return err + } - ch, err := s.Channel(m.ChannelID) - if err != nil { - return err - } + if err := mr.archiveAttachment(ctx, tx, emb.Image.URL, "attachments", m.ID); err != nil { + return err + } + } - if _, err := tx.ExecContext(ctx, "INSERT INTO discord_channels (id, guild_id, name, topic, nsfw) VALUES (?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET name = EXCLUDED.name, topic = EXCLUDED.topic, nsfw = EXCLUDED.nsfw", ch.ID, ch.GuildID, ch.Name, ch.Topic, ch.NSFW); err != nil { - return err - } + ch, err := s.Channel(m.ChannelID) + if err != nil { + return err + } - for _, emoji := range m.GetCustomEmojis() { - eURL := fmt.Sprintf("https://cdn.discordapp.com/emojis/%s?size=240&quality=lossless", emoji.ID) - if _, err := tx.ExecContext(ctx, "INSERT INTO discord_emoji (id, guild_id, name, url) VALUES (?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET name = EXCLUDED.name, url = EXCLUDED.url", emoji.ID, furryholeDiscord, emoji.Name, eURL); err != nil { + if _, err := tx.ExecContext(ctx, "INSERT INTO discord_channels (id, guild_id, name, topic, nsfw) VALUES (?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET name = EXCLUDED.name, topic = EXCLUDED.topic, nsfw = EXCLUDED.nsfw", ch.ID, ch.GuildID, ch.Name, ch.Topic, ch.NSFW); err != nil { return err } - mr.attachmentPreprocess.Add([3]string{eURL, "emoji", ""}, len(eURL)) - } - if err := tx.Commit(); err != nil { - return err + for _, emoji := range m.GetCustomEmojis() { + eURL := fmt.Sprintf("https://cdn.discordapp.com/emojis/%s?size=240&quality=lossless", emoji.ID) + if _, err := tx.ExecContext(ctx, "INSERT INTO discord_emoji (id, guild_id, name, url) VALUES (?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET name = EXCLUDED.name, url = EXCLUDED.url", emoji.ID, furryholeDiscord, emoji.Name, eURL); err != nil { + return err + } + mr.attachmentPreprocess.Add([3]string{eURL, "emoji", ""}, len(eURL)) + if err := mr.archiveAttachment(ctx, tx, eURL, "emoji", m.ID); err != nil { + return err + } + } } return nil diff --git a/cmd/marabot/main.go b/cmd/marabot/main.go index 8c16d74..b1cf4cd 100644 --- a/cmd/marabot/main.go +++ b/cmd/marabot/main.go @@ -5,6 +5,7 @@ import ( "context" "crypto/sha512" "database/sql" + "database/sql/driver" _ "embed" "flag" "fmt" @@ -16,13 +17,12 @@ import ( "syscall" "time" - _ "modernc.org/sqlite" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/bwmarrin/discordgo" + "github.com/tailscale/sqlite" "tailscale.com/hostinfo" "within.website/ln" "within.website/ln/opname" @@ -54,6 +54,19 @@ var ( dbSchema string ) +func openDB(fname string) (*sql.DB, error) { + db := sql.OpenDB(sqlite.Connector("file:"+fname, func(ctx context.Context, conn driver.ConnPrepareContext) error { + return sqlite.ExecScript(conn.(sqlite.SQLConn), dbSchema) + }, nil)) + + err := db.Ping() + if err != nil { + return nil, err + } + + return db, nil +} + func main() { internal.HandleStartup() @@ -64,16 +77,12 @@ func main() { ln.Log(ctx, ln.Action("starting up")) - db, err := sql.Open("sqlite", *dbFile) + db, err := openDB(*dbFile) if err != nil { ln.FatalErr(ctx, err, ln.Action("opening sqlite database")) } defer db.Close() - if _, err := db.ExecContext(ctx, dbSchema); err != nil { - ln.FatalErr(ctx, err, ln.Action("running database schema")) - } - ircmsgs := make(chan string, 10) // Init a new client. @@ -116,6 +125,7 @@ func main() { dg.AddHandler(mr.DiscordMessageCreate) dg.AddHandler(mr.DiscordMessageDelete) dg.AddHandler(mr.DiscordMessageEdit) + dg.AddHandler(mr.DiscordReactionAdd) if err := dg.Open(); err != nil { ln.FatalErr(ctx, err, ln.Action("opening discord client")) @@ -217,6 +227,63 @@ func (mr *MaraRevolt) preprocessLinks(ctx context.Context, data [][3]string) { } } +func (mr *MaraRevolt) archiveAttachment(ctx context.Context, tx *sql.Tx, link, kind, messageID string) error { + att, err := hashURL(link, kind) + if err != nil { + ln.Error(ctx, err, ln.F{"link": link, "kind": kind}) + + if werr, ok := err.(*web.Error); ok { + if werr.GotStatus == http.StatusNotFound { + tx.ExecContext(ctx, "DELETE FROM discord_users WHERE avatar_url = ?", link) + tx.ExecContext(ctx, "DELETE FROM discord_attachments WHERE url = ?", link) + tx.ExecContext(ctx, "DELETE FROM discord_emoji WHERE url = ?", link) + tx.ExecContext(ctx, "DELETE FROM revolt_attachments WHERE url = ?", link) + tx.ExecContext(ctx, "DELETE FROM revolt_users WHERE avatar_url = ?", link) + tx.ExecContext(ctx, "DELETE FROM revolt_emoji WHERE url = ?", link) + } else { + return err + } + } + } + + att.MessageID = aws.String(messageID) + + key := filepath.Join(att.Kind, att.ID) + + f := ln.F{"kind": att.Kind, "id": att.ID, "url": att.URL, "content_type": att.ContentType} + + var count int + if err := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM s3_uploads WHERE id = ?", att.ID).Scan(&count); err != nil { + ln.Error(ctx, err, f) + return err + } + + f["count"] = count + + if count != 0 { + return nil + } + + if _, err := mr.uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + Bucket: aws.String(*awsS3Bucket), + Key: aws.String(key), + ContentType: aws.String(att.ContentType), + Body: bytes.NewBuffer(att.Data), + Metadata: map[string]*string{ + "Original-URL": aws.String(att.URL), + "Message-ID": att.MessageID, + }, + }); err != nil { + return err + } + + if _, err := tx.ExecContext(ctx, "INSERT INTO s3_uploads(id, url, kind, content_type, created_at, message_id) VALUES (?, ?, ?, ?, ?, ?)", att.ID, att.URL, att.Kind, att.ContentType, att.CreatedAt, att.MessageID); err != nil { + ln.Error(ctx, err, ln.Action("saving upload information to DB"), f) + } + + return nil +} + func hashURL(itemURL, kind string) (*Attachment, error) { resp, err := http.Get(itemURL) if err != nil { diff --git a/cmd/marabot/revolt.go b/cmd/marabot/revolt.go index 5abea81..3f870cf 100644 --- a/cmd/marabot/revolt.go +++ b/cmd/marabot/revolt.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "strings" + "sync" "time" "github.com/aws/aws-sdk-go/service/s3/s3iface" @@ -23,6 +24,7 @@ type MaraRevolt struct { attachmentUpload *bundler.Bundler[*Attachment] uploader *s3manager.Uploader s3 s3iface.S3API + lock sync.Mutex revolt.NullHandler } diff --git a/cmd/marabot/schema.sql b/cmd/marabot/schema.sql index 5c5f045..9709a4a 100644 --- a/cmd/marabot/schema.sql +++ b/cmd/marabot/schema.sql @@ -96,7 +96,7 @@ CREATE INDEX IF NOT EXISTS discord_emoji_url ON discord_emoji(url); CREATE TABLE IF NOT EXISTS irc_messages ( - id SERIAL PRIMARY KEY, + id SERIAL PRIMARY KEY AUTOINCREMENT, nick TEXT NOT NULL, user TEXT NOT NULL, host TEXT NOT NULL, |
