diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/anubis.go | 50 | ||||
| -rw-r--r-- | lib/anubis_test.go | 104 | ||||
| -rw-r--r-- | lib/http.go | 3 |
3 files changed, 124 insertions, 33 deletions
diff --git a/lib/anubis.go b/lib/anubis.go index 6e40f95..83e04dd 100644 --- a/lib/anubis.go +++ b/lib/anubis.go @@ -67,6 +67,10 @@ type Options struct { Policy *policy.ParsedConfig ServeRobotsTXT bool PrivateKey ed25519.PrivateKey + + CookieDomain string + CookieName string + CookiePartitioned bool } func LoadPoliciesOrDefault(fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { @@ -108,6 +112,7 @@ func New(opts Options) (*Server, error) { priv: opts.PrivateKey, pub: opts.PrivateKey.Public().(ed25519.PublicKey), policy: opts.Policy, + opts: opts, DNSBLCache: decaymap.New[string, dnsbl.DroneBLResponse](), } @@ -145,6 +150,7 @@ type Server struct { priv ed25519.PrivateKey pub ed25519.PublicKey policy *policy.ParsedConfig + opts Options DNSBLCache *decaymap.Impl[string, dnsbl.DroneBLResponse] ChallengeDifficulty int } @@ -217,7 +223,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { s.next.ServeHTTP(w, r) return case config.RuleDeny: - ClearCookie(w) + s.ClearCookie(w) lg.Info("explicit deny") if rule == nil { lg.Error("rule is nil, cannot calculate checksum") @@ -236,7 +242,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { case config.RuleChallenge: lg.Debug("challenge requested") default: - ClearCookie(w) + s.ClearCookie(w) templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) return } @@ -244,21 +250,21 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { ckie, err := r.Cookie(anubis.CookieName) if err != nil { lg.Debug("cookie not found", "path", r.URL.Path) - ClearCookie(w) + s.ClearCookie(w) s.RenderIndex(w, r) return } if err := ckie.Valid(); err != nil { lg.Debug("cookie is invalid", "err", err) - ClearCookie(w) + s.ClearCookie(w) s.RenderIndex(w, r) return } if time.Now().After(ckie.Expires) && !ckie.Expires.IsZero() { lg.Debug("cookie expired", "path", r.URL.Path) - ClearCookie(w) + s.ClearCookie(w) s.RenderIndex(w, r) return } @@ -269,7 +275,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { if err != nil || !token.Valid { lg.Debug("invalid token", "path", r.URL.Path, "err", err) - ClearCookie(w) + s.ClearCookie(w) s.RenderIndex(w, r) return } @@ -284,7 +290,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { claims, ok := token.Claims.(jwt.MapClaims) if !ok { lg.Debug("invalid token claims type", "path", r.URL.Path) - ClearCookie(w) + s.ClearCookie(w) s.RenderIndex(w, r) return } @@ -292,7 +298,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { if claims["challenge"] != challenge { lg.Debug("invalid challenge", "path", r.URL.Path) - ClearCookie(w) + s.ClearCookie(w) s.RenderIndex(w, r) return } @@ -309,7 +315,7 @@ func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { if subtle.ConstantTimeCompare([]byte(claims["response"].(string)), []byte(calculated)) != 1 { lg.Debug("invalid response", "path", r.URL.Path) failedValidations.Inc() - ClearCookie(w) + s.ClearCookie(w) s.RenderIndex(w, r) return } @@ -372,7 +378,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { nonceStr := r.FormValue("nonce") if nonceStr == "" { - ClearCookie(w) + s.ClearCookie(w) lg.Debug("no nonce") templ.Handler(web.Base("Oh noes!", web.ErrorPage("missing nonce")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) return @@ -380,7 +386,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { elapsedTimeStr := r.FormValue("elapsedTime") if elapsedTimeStr == "" { - ClearCookie(w) + s.ClearCookie(w) lg.Debug("no elapsedTime") templ.Handler(web.Base("Oh noes!", web.ErrorPage("missing elapsedTime")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) return @@ -388,7 +394,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { elapsedTime, err := strconv.ParseFloat(elapsedTimeStr, 64) if err != nil { - ClearCookie(w) + s.ClearCookie(w) lg.Debug("elapsedTime doesn't parse", "err", err) templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid elapsedTime")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) return @@ -404,7 +410,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { nonce, err := strconv.Atoi(nonceStr) if err != nil { - ClearCookie(w) + s.ClearCookie(w) lg.Debug("nonce doesn't parse", "err", err) templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid nonce")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) return @@ -414,7 +420,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { calculated := internal.SHA256sum(calcString) if subtle.ConstantTimeCompare([]byte(response), []byte(calculated)) != 1 { - ClearCookie(w) + s.ClearCookie(w) lg.Debug("hash does not match", "got", response, "want", calculated) templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid response")), templ.WithStatus(http.StatusForbidden)).ServeHTTP(w, r) failedValidations.Inc() @@ -423,7 +429,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { // compare the leading zeroes if !strings.HasPrefix(response, strings.Repeat("0", s.ChallengeDifficulty)) { - ClearCookie(w) + s.ClearCookie(w) lg.Debug("difficulty check failed", "response", response, "difficulty", s.ChallengeDifficulty) templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid response")), templ.WithStatus(http.StatusForbidden)).ServeHTTP(w, r) failedValidations.Inc() @@ -442,17 +448,19 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { tokenString, err := token.SignedString(s.priv) if err != nil { lg.Error("failed to sign JWT", "err", err) - ClearCookie(w) + s.ClearCookie(w) templ.Handler(web.Base("Oh noes!", web.ErrorPage("failed to sign JWT")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) return } http.SetCookie(w, &http.Cookie{ - Name: anubis.CookieName, - Value: tokenString, - Expires: time.Now().Add(24 * 7 * time.Hour), - SameSite: http.SameSiteLaxMode, - Path: "/", + Name: anubis.CookieName, + Value: tokenString, + Expires: time.Now().Add(24 * 7 * time.Hour), + SameSite: http.SameSiteLaxMode, + Domain: s.opts.CookieDomain, + Partitioned: s.opts.CookiePartitioned, + Path: "/", }) challengesValidated.Inc() diff --git a/lib/anubis_test.go b/lib/anubis_test.go index 0498c13..90d2cdf 100644 --- a/lib/anubis_test.go +++ b/lib/anubis_test.go @@ -1,15 +1,18 @@ package lib import ( + "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" "github.com/TecharoHQ/anubis" + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/policy" ) -func spawnAnubis(t *testing.T, h http.Handler) string { +func loadPolicies(t *testing.T, fname string) *policy.ParsedConfig { t.Helper() policy, err := LoadPoliciesOrDefault("", anubis.DefaultDifficulty) @@ -17,23 +20,102 @@ func spawnAnubis(t *testing.T, h http.Handler) string { t.Fatal(err) } - s, err := New(Options{ - Next: h, - Policy: policy, - ServeRobotsTXT: true, - }) + return policy +} + +func spawnAnubis(t *testing.T, opts Options) *Server { + t.Helper() + + s, err := New(opts) if err != nil { t.Fatalf("can't construct libanubis.Server: %v", err) } - ts := httptest.NewServer(s) - t.Log(ts.URL) + return s +} + +func TestCookieSettings(t *testing.T) { + pol := loadPolicies(t, "") + pol.DefaultDifficulty = 0 - t.Cleanup(func() { - ts.Close() + srv := spawnAnubis(t, Options{ + Next: http.NewServeMux(), + Policy: pol, + + CookieDomain: "local.cetacean.club", + CookiePartitioned: true, + CookieName: t.Name(), }) - return ts.URL + ts := httptest.NewServer(internal.DefaultXRealIP("127.0.0.1", srv)) + defer ts.Close() + + cli := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := cli.Post(ts.URL+"/.within.website/x/cmd/anubis/api/make-challenge", "", nil) + if err != nil { + t.Fatalf("can't request challenge: %v", err) + } + defer resp.Body.Close() + + var chall = struct { + Challenge string `json:"challenge"` + }{} + if err := json.NewDecoder(resp.Body).Decode(&chall); err != nil { + t.Fatalf("can't read challenge response body: %v", err) + } + + nonce := 0 + elapsedTime := 420 + redir := "/" + calcString := fmt.Sprintf("%s%d", chall.Challenge, nonce) + calculated := internal.SHA256sum(calcString) + + req, err := http.NewRequest(http.MethodGet, ts.URL+"/.within.website/x/cmd/anubis/api/pass-challenge", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + q := req.URL.Query() + q.Set("response", calculated) + q.Set("nonce", fmt.Sprint(nonce)) + q.Set("redir", redir) + q.Set("elapsedTime", fmt.Sprint(elapsedTime)) + req.URL.RawQuery = q.Encode() + + resp, err = cli.Do(req) + if err != nil { + t.Fatalf("can't do challenge passing") + } + + if resp.StatusCode != http.StatusFound { + t.Errorf("wanted %d, got: %d", http.StatusFound, resp.StatusCode) + } + + var ckie *http.Cookie + for _, cookie := range resp.Cookies() { + t.Logf("%#v", cookie) + if cookie.Name == anubis.CookieName { + ckie = cookie + break + } + } + + if ckie.Domain != "local.cetacean.club" { + t.Errorf("cookie domain is wrong, wanted local.cetacean.club, got: %s", ckie.Domain) + } + + if ckie.Partitioned != srv.opts.CookiePartitioned { + t.Errorf("wanted partitioned flag %v, got: %v", srv.opts.CookiePartitioned, ckie.Partitioned) + } + + if ckie == nil { + t.Errorf("Cookie %q not found", anubis.CookieName) + } } func TestCheckDefaultDifficultyMatchesPolicy(t *testing.T) { diff --git a/lib/http.go b/lib/http.go index 1284523..2f32b6d 100644 --- a/lib/http.go +++ b/lib/http.go @@ -7,13 +7,14 @@ import ( "github.com/TecharoHQ/anubis" ) -func ClearCookie(w http.ResponseWriter) { +func (s *Server) ClearCookie(w http.ResponseWriter) { http.SetCookie(w, &http.Cookie{ Name: anubis.CookieName, Value: "", Expires: time.Now().Add(-1 * time.Hour), MaxAge: -1, SameSite: http.SameSiteLaxMode, + Domain: s.opts.CookieDomain, }) } |
