diff options
| -rw-r--r-- | mastodon/robocadey/gpt2/.gitignore | 1 | ||||
| -rwxr-xr-x | mastodon/robocadey/gpt2/main.py | 41 | ||||
| -rw-r--r-- | mastodon/robocadey/main.go | 76 |
3 files changed, 52 insertions, 66 deletions
diff --git a/mastodon/robocadey/gpt2/.gitignore b/mastodon/robocadey/gpt2/.gitignore new file mode 100644 index 0000000..92b189f --- /dev/null +++ b/mastodon/robocadey/gpt2/.gitignore @@ -0,0 +1 @@ +checkpoint diff --git a/mastodon/robocadey/gpt2/main.py b/mastodon/robocadey/gpt2/main.py index 82d4e5e..750f338 100755 --- a/mastodon/robocadey/gpt2/main.py +++ b/mastodon/robocadey/gpt2/main.py @@ -7,36 +7,27 @@ import socket import sys from datetime import datetime -sockpath = "/xe/gpt2/checkpoint/server.sock" - sess = gpt2.start_tf_sess() gpt2.load_gpt2(sess, run_name='run1') -if os.path.exists(sockpath): - os.remove(sockpath) - -sock = socket.socket(socket.AF_UNIX) -sock.bind(sockpath) +SYSTEMD_FIRST_SOCKET_FD = 3 +sock = socket.fromfd(SYSTEMD_FIRST_SOCKET_FD, socket.AF_UNIX, socket.SOCK_STREAM) -print("Listening on", sockpath) sock.listen(1) while True: connection, client_address = sock.accept() - try: - print("generating shitpost") - result = gpt2.generate(sess, - length=512, - temperature=0.8, - nsamples=1, - batch_size=1, - return_as_list=True, - top_p=0.9, - )[0].split("\n")[1:][:-1] - print("shitpost generated") - connection.send(json.dumps(result).encode()) - finally: - connection.close() - -server.close() -os.remove(sockpath) + print("generating shitpost") + result = gpt2.generate(sess, + length=512, + temperature=0.8, + nsamples=1, + batch_size=1, + return_as_list=True, + top_p=0.9, + )[0].split("\n")[1:][:-1] + print("shitpost generated") + connection.send(json.dumps(result).encode()) + connection.close() + +sock.close() diff --git a/mastodon/robocadey/main.go b/mastodon/robocadey/main.go index 805cd5d..09e81ca 100644 --- a/mastodon/robocadey/main.go +++ b/mastodon/robocadey/main.go @@ -4,15 +4,13 @@ import ( "context" "encoding/json" "flag" - "fmt" "math/rand" - "os" + "net" "time" "github.com/McKael/madon/v2" "within.website/ln" "within.website/x/internal" - "within.website/x/markov" ) var ( @@ -20,44 +18,46 @@ var ( appID = flag.String("app-id", "", "oauth2 app id") appSecret = flag.String("app-secret", "", "oauth2 app secret") token = flag.String("token", "", "oauth2 token") - state = flag.String("state", "./robocadey.gob", "state file") - readFrom = flag.String("read-from", "", "if set, read from this JSON file") + sockPath = flag.String("gpt2-sock", "/run/robocadey-gpt2.sock", "path to unix socket for robocadey-gpt2") ) var scopes = []string{"read", "write", "follow"} -func main() { - internal.HandleStartup() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +func getShitposts(sockPath string) ([]string, error) { + var conn net.Conn + var err error + if sockPath != "" { + conn, err = net.Dial("unix", sockPath) + } else { + conn, err = net.Dial("tcp", "[::1]:9999") + } - if *readFrom != "" { - os.Remove(*state) - fin, err := os.Open(*readFrom) - if err != nil { - ln.FatalErr(ctx, err) - } - defer fin.Close() + if err != nil { + return nil, err + } + defer conn.Close() + var result []string + err = json.NewDecoder(conn).Decode(&result) + if err != nil { + return nil, err + } - var lines []string - c := markov.NewChain(3) - err = json.NewDecoder(fin).Decode(&lines) - if err != nil { - ln.FatalErr(ctx, err) - } + return result, nil +} - for _, line := range lines { - c.Write(line) - } +func getShitpost(ctx context.Context) string { + shitposts, err := getShitposts(*sockPath) + if err != nil { + ln.FatalErr(ctx, err) + } - err = c.Save(*state) - if err != nil { - ln.FatalErr(ctx, err) - } + return shitposts[rand.Intn(len(shitposts))] +} - fmt.Println("data imported successfully") - return - } +func main() { + internal.HandleStartup() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() c, err := madon.RestoreApp("furry boost bot", *instance, *appID, *appSecret, &madon.UserToken{AccessToken: *token}) if err != nil { @@ -65,16 +65,10 @@ func main() { } _ = c - chain := markov.NewChain(3) - err = chain.Load(*state) - if err != nil { - ln.FatalErr(ctx, err) - } - rand.Seed(time.Now().UnixMicro()) if _, err := c.PostStatus(madon.PostStatusParams{ - Text: chain.Generate(150), + Text: getShitpost(ctx), }); err != nil { ln.FatalErr(ctx, err) } @@ -102,7 +96,7 @@ func main() { case <-t: if _, err := c.PostStatus(madon.PostStatusParams{ - Text: chain.Generate(150), + Text: getShitpost(ctx), }); err != nil { ln.FatalErr(ctx, err) } @@ -120,7 +114,7 @@ func main() { "target": n.Account.Acct, }) if _, err := c.PostStatus(madon.PostStatusParams{ - Text: "@" + n.Account.Acct + " " + chain.Generate(150), + Text: "@" + n.Account.Acct + " " + getShitpost(ctx), InReplyTo: n.Status.ID, }); err != nil { ln.FatalErr(ctx, err) |
