aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mastodon/robocadey/gpt2/.gitignore1
-rwxr-xr-xmastodon/robocadey/gpt2/main.py41
-rw-r--r--mastodon/robocadey/main.go76
3 files changed, 52 insertions, 66 deletions
diff --git a/mastodon/robocadey/gpt2/.gitignore b/mastodon/robocadey/gpt2/.gitignore
new file mode 100644
index 0000000..92b189f
--- /dev/null
+++ b/mastodon/robocadey/gpt2/.gitignore
@@ -0,0 +1 @@
+checkpoint
diff --git a/mastodon/robocadey/gpt2/main.py b/mastodon/robocadey/gpt2/main.py
index 82d4e5e..750f338 100755
--- a/mastodon/robocadey/gpt2/main.py
+++ b/mastodon/robocadey/gpt2/main.py
@@ -7,36 +7,27 @@ import socket
import sys
from datetime import datetime
-sockpath = "/xe/gpt2/checkpoint/server.sock"
-
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')
-if os.path.exists(sockpath):
- os.remove(sockpath)
-
-sock = socket.socket(socket.AF_UNIX)
-sock.bind(sockpath)
+SYSTEMD_FIRST_SOCKET_FD = 3
+sock = socket.fromfd(SYSTEMD_FIRST_SOCKET_FD, socket.AF_UNIX, socket.SOCK_STREAM)
-print("Listening on", sockpath)
sock.listen(1)
while True:
connection, client_address = sock.accept()
- try:
- print("generating shitpost")
- result = gpt2.generate(sess,
- length=512,
- temperature=0.8,
- nsamples=1,
- batch_size=1,
- return_as_list=True,
- top_p=0.9,
- )[0].split("\n")[1:][:-1]
- print("shitpost generated")
- connection.send(json.dumps(result).encode())
- finally:
- connection.close()
-
-server.close()
-os.remove(sockpath)
+ print("generating shitpost")
+ result = gpt2.generate(sess,
+ length=512,
+ temperature=0.8,
+ nsamples=1,
+ batch_size=1,
+ return_as_list=True,
+ top_p=0.9,
+ )[0].split("\n")[1:][:-1]
+ print("shitpost generated")
+ connection.send(json.dumps(result).encode())
+ connection.close()
+
+sock.close()
diff --git a/mastodon/robocadey/main.go b/mastodon/robocadey/main.go
index 805cd5d..09e81ca 100644
--- a/mastodon/robocadey/main.go
+++ b/mastodon/robocadey/main.go
@@ -4,15 +4,13 @@ import (
"context"
"encoding/json"
"flag"
- "fmt"
"math/rand"
- "os"
+ "net"
"time"
"github.com/McKael/madon/v2"
"within.website/ln"
"within.website/x/internal"
- "within.website/x/markov"
)
var (
@@ -20,44 +18,46 @@ var (
appID = flag.String("app-id", "", "oauth2 app id")
appSecret = flag.String("app-secret", "", "oauth2 app secret")
token = flag.String("token", "", "oauth2 token")
- state = flag.String("state", "./robocadey.gob", "state file")
- readFrom = flag.String("read-from", "", "if set, read from this JSON file")
+ sockPath = flag.String("gpt2-sock", "/run/robocadey-gpt2.sock", "path to unix socket for robocadey-gpt2")
)
var scopes = []string{"read", "write", "follow"}
-func main() {
- internal.HandleStartup()
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+func getShitposts(sockPath string) ([]string, error) {
+ var conn net.Conn
+ var err error
+ if sockPath != "" {
+ conn, err = net.Dial("unix", sockPath)
+ } else {
+ conn, err = net.Dial("tcp", "[::1]:9999")
+ }
- if *readFrom != "" {
- os.Remove(*state)
- fin, err := os.Open(*readFrom)
- if err != nil {
- ln.FatalErr(ctx, err)
- }
- defer fin.Close()
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+ var result []string
+ err = json.NewDecoder(conn).Decode(&result)
+ if err != nil {
+ return nil, err
+ }
- var lines []string
- c := markov.NewChain(3)
- err = json.NewDecoder(fin).Decode(&lines)
- if err != nil {
- ln.FatalErr(ctx, err)
- }
+ return result, nil
+}
- for _, line := range lines {
- c.Write(line)
- }
+func getShitpost(ctx context.Context) string {
+ shitposts, err := getShitposts(*sockPath)
+ if err != nil {
+ ln.FatalErr(ctx, err)
+ }
- err = c.Save(*state)
- if err != nil {
- ln.FatalErr(ctx, err)
- }
+ return shitposts[rand.Intn(len(shitposts))]
+}
- fmt.Println("data imported successfully")
- return
- }
+func main() {
+ internal.HandleStartup()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
c, err := madon.RestoreApp("furry boost bot", *instance, *appID, *appSecret, &madon.UserToken{AccessToken: *token})
if err != nil {
@@ -65,16 +65,10 @@ func main() {
}
_ = c
- chain := markov.NewChain(3)
- err = chain.Load(*state)
- if err != nil {
- ln.FatalErr(ctx, err)
- }
-
rand.Seed(time.Now().UnixMicro())
if _, err := c.PostStatus(madon.PostStatusParams{
- Text: chain.Generate(150),
+ Text: getShitpost(ctx),
}); err != nil {
ln.FatalErr(ctx, err)
}
@@ -102,7 +96,7 @@ func main() {
case <-t:
if _, err := c.PostStatus(madon.PostStatusParams{
- Text: chain.Generate(150),
+ Text: getShitpost(ctx),
}); err != nil {
ln.FatalErr(ctx, err)
}
@@ -120,7 +114,7 @@ func main() {
"target": n.Account.Acct,
})
if _, err := c.PostStatus(madon.PostStatusParams{
- Text: "@" + n.Account.Acct + " " + chain.Generate(150),
+ Text: "@" + n.Account.Acct + " " + getShitpost(ctx),
InReplyTo: n.Status.ID,
}); err != nil {
ln.FatalErr(ctx, err)