aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorsoopyc <me@soopy.moe>2025-03-21 22:58:05 +0800
committerGitHub <noreply@github.com>2025-03-21 10:58:05 -0400
commit1c00431098247d696dbe4ff06b2bc5e036230e8c (patch)
tree6a972767107a7fd86cd0ed8c9adeffe1494fd5de /cmd
parentd93adbc11194303427b2f8f8d7db849fb67d6c1a (diff)
downloadanubis-1c00431098247d696dbe4ff06b2bc5e036230e8c.tar.xz
anubis-1c00431098247d696dbe4ff06b2bc5e036230e8c.zip
general unix domain sockets support (#45)
* feat: allow binding to unix domain sockets this is useful when the user does not want to expose more tcp ports than needed. also simplifes configuration in some situation, like with nixos modules as the socket paths can be automatically configured. docs updated with additional configuration flags. Signed-off-by: Cassie Cheung <me@soopy.moe> * feat: graceful shutdown and cleanup on signal this is needed to clean up left-over unix sockets, else on the next boot listener panics with `address already in use`. Co-authored-by: cat <cat@gensokyo.uk> Signed-off-by: Cassie Cheung <me@soopy.moe> * feat: support unix socket upstream targets adds support for proxying unix socket upstreams, essentially allowing anubis to run without listening on tcp sockets at all*. *for metrics, neither prometheus and victoriametrics supports scraping from unix sockets. if metrics are desired, tcp sockets are still needed. Co-authored-by: cat <cat@gensokyo.uk> Signed-off-by: Cassie Cheung <me@soopy.moe> * docs: add changelog entry --------- Signed-off-by: Cassie Cheung <me@soopy.moe> Co-authored-by: cat <cat@gensokyo.uk>
Diffstat (limited to 'cmd')
-rw-r--r--cmd/anubis/main.go132
1 files changed, 123 insertions, 9 deletions
diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go
index 52b38d8..f2f7255 100644
--- a/cmd/anubis/main.go
+++ b/cmd/anubis/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
@@ -15,12 +16,16 @@ import (
"log/slog"
"math"
mrand "math/rand"
+ "net"
"net/http"
"net/http/httputil"
"net/url"
"os"
+ "os/signal"
"strconv"
"strings"
+ "sync"
+ "syscall"
"time"
"github.com/TecharoHQ/anubis"
@@ -37,9 +42,12 @@ import (
)
var (
- bind = flag.String("bind", ":8923", "TCP port to bind HTTP to")
+ bind = flag.String("bind", ":8923", "network address to bind HTTP to")
+ bindNetwork = flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp")
challengeDifficulty = flag.Int("difficulty", 4, "difficulty of the challenge")
- metricsBind = flag.String("metrics-bind", ":9090", "TCP port to bind metrics to")
+ metricsBind = flag.String("metrics-bind", ":9090", "network address to bind metrics to")
+ metricsBindNetwork = flag.String("metrics-bind-network", "tcp", "network family for the metrics server to bind to")
+ socketMode = flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.")
robotsTxt = flag.Bool("serve-robots-txt", false, "serve a robots.txt file that disallows all robots")
policyFname = flag.String("policy-fname", "", "full path to anubis policy document (defaults to a sensible built-in policy)")
slogLevel = flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
@@ -101,6 +109,40 @@ func doHealthCheck() error {
return nil
}
+func setupListener(network string, address string) (net.Listener, string) {
+ formattedAddress := ""
+ switch network {
+ case "unix":
+ formattedAddress = "unix:" + address
+ case "tcp":
+ formattedAddress = "http://localhost" + address
+ default:
+ formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
+ }
+
+ listener, err := net.Listen(network, address)
+ if err != nil {
+ log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err))
+ }
+
+ // additional permission handling for unix sockets
+ if network == "unix" {
+ mode, err := strconv.ParseUint(*socketMode, 8, 0)
+ if err != nil {
+ listener.Close()
+ log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", *socketMode, err))
+ }
+
+ err = os.Chmod(address, os.FileMode(mode))
+ if err != nil {
+ listener.Close()
+ log.Fatal(fmt.Errorf("could not change socket mode: %w", err))
+ }
+ }
+
+ return listener, formattedAddress
+}
+
func main() {
flagenv.Parse()
flag.Parse()
@@ -155,20 +197,59 @@ func main() {
})
}
+ wg := new(sync.WaitGroup)
+ // install signal handler
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+ defer stop()
+
if *metricsBind != "" {
- go metricsServer()
+ wg.Add(1)
+ go metricsServer(ctx, wg.Done)
}
mux.HandleFunc("/", s.maybeReverseProxy)
- slog.Info("listening", "url", "http://localhost"+*bind, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version)
- log.Fatal(http.ListenAndServe(*bind, mux))
+ srv := http.Server{Handler: mux}
+ listener, url := setupListener(*bindNetwork, *bind)
+ slog.Info("listening", "url", url, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version)
+
+ go func() {
+ <-ctx.Done()
+ c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := srv.Shutdown(c); err != nil {
+ log.Printf("cannot shut down: %v", err)
+ }
+ }()
+
+ if err := srv.Serve(listener); err != http.ErrServerClosed {
+ log.Fatal(err)
+ }
+ wg.Wait()
}
-func metricsServer() {
- http.DefaultServeMux.Handle("/metrics", promhttp.Handler())
- slog.Debug("listening for metrics", "url", "http://localhost"+*metricsBind)
- log.Fatal(http.ListenAndServe(*metricsBind, nil))
+func metricsServer(ctx context.Context, done func()) {
+ defer done()
+
+ mux := http.NewServeMux()
+ mux.Handle("/metrics", promhttp.Handler())
+
+ srv := http.Server{Handler: mux}
+ listener, url := setupListener(*metricsBindNetwork, *metricsBind)
+ slog.Debug("listening for metrics", "url", url)
+
+ go func() {
+ <-ctx.Done()
+ c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := srv.Shutdown(c); err != nil {
+ log.Printf("cannot shut down: %v", err)
+ }
+ }()
+
+ if err := srv.Serve(listener); err != http.ErrServerClosed {
+ log.Fatal(err)
+ }
}
func sha256sum(text string) (string, error) {
@@ -207,7 +288,24 @@ func New(target, policyFname string) (*Server, error) {
return nil, fmt.Errorf("failed to generate ed25519 key: %w", err)
}
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+
+ // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
+ if u.Scheme == "unix" {
+ // clean path up so we don't use the socket path in proxied requests
+ addr := u.Path
+ u.Path = ""
+ // tell transport how to dial unix sockets
+ transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+ dialer := net.Dialer{}
+ return dialer.DialContext(ctx, "unix", addr)
+ }
+ // tell transport how to handle the unix url scheme
+ transport.RegisterProtocol("unix", unixRoundTripper{Transport: transport})
+ }
+
rp := httputil.NewSingleHostReverseProxy(u)
+ rp.Transport = transport
var fin io.ReadCloser
@@ -240,6 +338,22 @@ func New(target, policyFname string) (*Server, error) {
}, nil
}
+// https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
+type unixRoundTripper struct {
+ Transport *http.Transport
+}
+
+// set bare minimum stuff
+func (t unixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ req = req.Clone(req.Context())
+ if req.Host == "" {
+ req.Host = "localhost"
+ }
+ req.URL.Host = req.Host // proxy error: no Host in request URL
+ req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
+ return t.Transport.RoundTrip(req)
+}
+
type Server struct {
rp *httputil.ReverseProxy
priv ed25519.PrivateKey