diff options
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/anubis/main.go | 132 |
1 files changed, 123 insertions, 9 deletions
diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go index 52b38d8..f2f7255 100644 --- a/cmd/anubis/main.go +++ b/cmd/anubis/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/ed25519" "crypto/rand" "crypto/sha256" @@ -15,12 +16,16 @@ import ( "log/slog" "math" mrand "math/rand" + "net" "net/http" "net/http/httputil" "net/url" "os" + "os/signal" "strconv" "strings" + "sync" + "syscall" "time" "github.com/TecharoHQ/anubis" @@ -37,9 +42,12 @@ import ( ) var ( - bind = flag.String("bind", ":8923", "TCP port to bind HTTP to") + bind = flag.String("bind", ":8923", "network address to bind HTTP to") + bindNetwork = flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp") challengeDifficulty = flag.Int("difficulty", 4, "difficulty of the challenge") - metricsBind = flag.String("metrics-bind", ":9090", "TCP port to bind metrics to") + metricsBind = flag.String("metrics-bind", ":9090", "network address to bind metrics to") + metricsBindNetwork = flag.String("metrics-bind-network", "tcp", "network family for the metrics server to bind to") + socketMode = flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.") robotsTxt = flag.Bool("serve-robots-txt", false, "serve a robots.txt file that disallows all robots") policyFname = flag.String("policy-fname", "", "full path to anubis policy document (defaults to a sensible built-in policy)") slogLevel = flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)") @@ -101,6 +109,40 @@ func doHealthCheck() error { return nil } +func setupListener(network string, address string) (net.Listener, string) { + formattedAddress := "" + switch network { + case "unix": + formattedAddress = "unix:" + address + case "tcp": + formattedAddress = "http://localhost" + address + default: + formattedAddress = fmt.Sprintf(`(%s) %s`, network, address) + } + + listener, err := net.Listen(network, address) + if err != nil { + log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err)) + } + + // additional permission handling for unix sockets + if network == "unix" { + mode, err := strconv.ParseUint(*socketMode, 8, 0) + if err != nil { + listener.Close() + log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", *socketMode, err)) + } + + err = os.Chmod(address, os.FileMode(mode)) + if err != nil { + listener.Close() + log.Fatal(fmt.Errorf("could not change socket mode: %w", err)) + } + } + + return listener, formattedAddress +} + func main() { flagenv.Parse() flag.Parse() @@ -155,20 +197,59 @@ func main() { }) } + wg := new(sync.WaitGroup) + // install signal handler + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + if *metricsBind != "" { - go metricsServer() + wg.Add(1) + go metricsServer(ctx, wg.Done) } mux.HandleFunc("/", s.maybeReverseProxy) - slog.Info("listening", "url", "http://localhost"+*bind, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version) - log.Fatal(http.ListenAndServe(*bind, mux)) + srv := http.Server{Handler: mux} + listener, url := setupListener(*bindNetwork, *bind) + slog.Info("listening", "url", url, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version) + + go func() { + <-ctx.Done() + c, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(c); err != nil { + log.Printf("cannot shut down: %v", err) + } + }() + + if err := srv.Serve(listener); err != http.ErrServerClosed { + log.Fatal(err) + } + wg.Wait() } -func metricsServer() { - http.DefaultServeMux.Handle("/metrics", promhttp.Handler()) - slog.Debug("listening for metrics", "url", "http://localhost"+*metricsBind) - log.Fatal(http.ListenAndServe(*metricsBind, nil)) +func metricsServer(ctx context.Context, done func()) { + defer done() + + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.Handler()) + + srv := http.Server{Handler: mux} + listener, url := setupListener(*metricsBindNetwork, *metricsBind) + slog.Debug("listening for metrics", "url", url) + + go func() { + <-ctx.Done() + c, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(c); err != nil { + log.Printf("cannot shut down: %v", err) + } + }() + + if err := srv.Serve(listener); err != http.ErrServerClosed { + log.Fatal(err) + } } func sha256sum(text string) (string, error) { @@ -207,7 +288,24 @@ func New(target, policyFname string) (*Server, error) { return nil, fmt.Errorf("failed to generate ed25519 key: %w", err) } + transport := http.DefaultTransport.(*http.Transport).Clone() + + // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124 + if u.Scheme == "unix" { + // clean path up so we don't use the socket path in proxied requests + addr := u.Path + u.Path = "" + // tell transport how to dial unix sockets + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", addr) + } + // tell transport how to handle the unix url scheme + transport.RegisterProtocol("unix", unixRoundTripper{Transport: transport}) + } + rp := httputil.NewSingleHostReverseProxy(u) + rp.Transport = transport var fin io.ReadCloser @@ -240,6 +338,22 @@ func New(target, policyFname string) (*Server, error) { }, nil } +// https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124 +type unixRoundTripper struct { + Transport *http.Transport +} + +// set bare minimum stuff +func (t unixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + if req.Host == "" { + req.Host = "localhost" + } + req.URL.Host = req.Host // proxy error: no Host in request URL + req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion + return t.Transport.RoundTrip(req) +} + type Server struct { rp *httputil.ReverseProxy priv ed25519.PrivateKey |
