aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
Diffstat (limited to 'cmd')
-rw-r--r--cmd/anubis/main.go132
1 files changed, 123 insertions, 9 deletions
diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go
index 52b38d8..f2f7255 100644
--- a/cmd/anubis/main.go
+++ b/cmd/anubis/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
@@ -15,12 +16,16 @@ import (
"log/slog"
"math"
mrand "math/rand"
+ "net"
"net/http"
"net/http/httputil"
"net/url"
"os"
+ "os/signal"
"strconv"
"strings"
+ "sync"
+ "syscall"
"time"
"github.com/TecharoHQ/anubis"
@@ -37,9 +42,12 @@ import (
)
var (
- bind = flag.String("bind", ":8923", "TCP port to bind HTTP to")
+ bind = flag.String("bind", ":8923", "network address to bind HTTP to")
+ bindNetwork = flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp")
challengeDifficulty = flag.Int("difficulty", 4, "difficulty of the challenge")
- metricsBind = flag.String("metrics-bind", ":9090", "TCP port to bind metrics to")
+ metricsBind = flag.String("metrics-bind", ":9090", "network address to bind metrics to")
+ metricsBindNetwork = flag.String("metrics-bind-network", "tcp", "network family for the metrics server to bind to")
+ socketMode = flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.")
robotsTxt = flag.Bool("serve-robots-txt", false, "serve a robots.txt file that disallows all robots")
policyFname = flag.String("policy-fname", "", "full path to anubis policy document (defaults to a sensible built-in policy)")
slogLevel = flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
@@ -101,6 +109,40 @@ func doHealthCheck() error {
return nil
}
+func setupListener(network string, address string) (net.Listener, string) {
+ formattedAddress := ""
+ switch network {
+ case "unix":
+ formattedAddress = "unix:" + address
+ case "tcp":
+ formattedAddress = "http://localhost" + address
+ default:
+ formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
+ }
+
+ listener, err := net.Listen(network, address)
+ if err != nil {
+ log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err))
+ }
+
+ // additional permission handling for unix sockets
+ if network == "unix" {
+ mode, err := strconv.ParseUint(*socketMode, 8, 0)
+ if err != nil {
+ listener.Close()
+ log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", *socketMode, err))
+ }
+
+ err = os.Chmod(address, os.FileMode(mode))
+ if err != nil {
+ listener.Close()
+ log.Fatal(fmt.Errorf("could not change socket mode: %w", err))
+ }
+ }
+
+ return listener, formattedAddress
+}
+
func main() {
flagenv.Parse()
flag.Parse()
@@ -155,20 +197,59 @@ func main() {
})
}
+ wg := new(sync.WaitGroup)
+ // install signal handler
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+ defer stop()
+
if *metricsBind != "" {
- go metricsServer()
+ wg.Add(1)
+ go metricsServer(ctx, wg.Done)
}
mux.HandleFunc("/", s.maybeReverseProxy)
- slog.Info("listening", "url", "http://localhost"+*bind, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version)
- log.Fatal(http.ListenAndServe(*bind, mux))
+ srv := http.Server{Handler: mux}
+ listener, url := setupListener(*bindNetwork, *bind)
+ slog.Info("listening", "url", url, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version)
+
+ go func() {
+ <-ctx.Done()
+ c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := srv.Shutdown(c); err != nil {
+ log.Printf("cannot shut down: %v", err)
+ }
+ }()
+
+ if err := srv.Serve(listener); err != http.ErrServerClosed {
+ log.Fatal(err)
+ }
+ wg.Wait()
}
-func metricsServer() {
- http.DefaultServeMux.Handle("/metrics", promhttp.Handler())
- slog.Debug("listening for metrics", "url", "http://localhost"+*metricsBind)
- log.Fatal(http.ListenAndServe(*metricsBind, nil))
+func metricsServer(ctx context.Context, done func()) {
+ defer done()
+
+ mux := http.NewServeMux()
+ mux.Handle("/metrics", promhttp.Handler())
+
+ srv := http.Server{Handler: mux}
+ listener, url := setupListener(*metricsBindNetwork, *metricsBind)
+ slog.Debug("listening for metrics", "url", url)
+
+ go func() {
+ <-ctx.Done()
+ c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := srv.Shutdown(c); err != nil {
+ log.Printf("cannot shut down: %v", err)
+ }
+ }()
+
+ if err := srv.Serve(listener); err != http.ErrServerClosed {
+ log.Fatal(err)
+ }
}
func sha256sum(text string) (string, error) {
@@ -207,7 +288,24 @@ func New(target, policyFname string) (*Server, error) {
return nil, fmt.Errorf("failed to generate ed25519 key: %w", err)
}
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+
+ // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
+ if u.Scheme == "unix" {
+ // clean path up so we don't use the socket path in proxied requests
+ addr := u.Path
+ u.Path = ""
+ // tell transport how to dial unix sockets
+ transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+ dialer := net.Dialer{}
+ return dialer.DialContext(ctx, "unix", addr)
+ }
+ // tell transport how to handle the unix url scheme
+ transport.RegisterProtocol("unix", unixRoundTripper{Transport: transport})
+ }
+
rp := httputil.NewSingleHostReverseProxy(u)
+ rp.Transport = transport
var fin io.ReadCloser
@@ -240,6 +338,22 @@ func New(target, policyFname string) (*Server, error) {
}, nil
}
+// https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
+type unixRoundTripper struct {
+ Transport *http.Transport
+}
+
+// set bare minimum stuff
+func (t unixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ req = req.Clone(req.Context())
+ if req.Host == "" {
+ req.Host = "localhost"
+ }
+ req.URL.Host = req.Host // proxy error: no Host in request URL
+ req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
+ return t.Transport.RoundTrip(req)
+}
+
type Server struct {
rp *httputil.ReverseProxy
priv ed25519.PrivateKey