diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/anubis.go | 519 | ||||
| -rw-r--r-- | lib/checkresult.go | 25 | ||||
| -rw-r--r-- | lib/http.go | 34 | ||||
| -rw-r--r-- | lib/policy/bot.go | 32 | ||||
| -rw-r--r-- | lib/policy/config/config.go | 162 | ||||
| -rw-r--r-- | lib/policy/config/config_test.go | 248 | ||||
| -rw-r--r-- | lib/policy/config/testdata/bad/badregexes.json | 14 | ||||
| -rw-r--r-- | lib/policy/config/testdata/bad/invalid.json | 5 | ||||
| -rw-r--r-- | lib/policy/config/testdata/bad/nobots.json | 1 | ||||
| -rw-r--r-- | lib/policy/config/testdata/good/allow_everyone.json | 12 | ||||
| -rw-r--r-- | lib/policy/config/testdata/good/challengemozilla.json | 9 | ||||
| -rw-r--r-- | lib/policy/config/testdata/good/everything_blocked.json | 10 | ||||
| -rw-r--r-- | lib/policy/policy.go | 122 | ||||
| -rw-r--r-- | lib/policy/policy_test.go | 68 | ||||
| -rw-r--r-- | lib/random.go | 9 |
15 files changed, 1270 insertions, 0 deletions
diff --git a/lib/anubis.go b/lib/anubis.go new file mode 100644 index 0000000..5953a48 --- /dev/null +++ b/lib/anubis.go @@ -0,0 +1,519 @@ +package lib + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/json" + "fmt" + "io" + "log" + "log/slog" + "math" + "net" + "net/http" + "os" + "strconv" + "strings" + "time" + + "github.com/a-h/templ" + "github.com/golang-jwt/jwt/v5" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/TecharoHQ/anubis" + "github.com/TecharoHQ/anubis/data" + "github.com/TecharoHQ/anubis/decaymap" + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/internal/dnsbl" + "github.com/TecharoHQ/anubis/lib/policy" + "github.com/TecharoHQ/anubis/lib/policy/config" + "github.com/TecharoHQ/anubis/web" + "github.com/TecharoHQ/anubis/xess" +) + +var ( + challengesIssued = promauto.NewCounter(prometheus.CounterOpts{ + Name: "anubis_challenges_issued", + Help: "The total number of challenges issued", + }) + + challengesValidated = promauto.NewCounter(prometheus.CounterOpts{ + Name: "anubis_challenges_validated", + Help: "The total number of challenges validated", + }) + + droneBLHits = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "anubis_dronebl_hits", + Help: "The total number of hits from DroneBL", + }, []string{"status"}) + + failedValidations = promauto.NewCounter(prometheus.CounterOpts{ + Name: "anubis_failed_validations", + Help: "The total number of failed validations", + }) + + timeTaken = promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "anubis_time_taken", + Help: "The time taken for a browser to generate a response (milliseconds)", + Buckets: prometheus.ExponentialBucketsRange(1, math.Pow(2, 18), 19), + }) +) + +type Options struct { + Next http.Handler + Policy *policy.ParsedConfig + ServeRobotsTXT bool +} + +func LoadPoliciesOrDefault(fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { + var fin io.ReadCloser + var err error + + if fname != "" { + fin, err = os.Open(fname) + if err != nil { + return nil, fmt.Errorf("can't parse policy file %s: %w", fname, err) + } + } else { + fname = "(data)/botPolicies.json" + fin, err = data.BotPolicies.Open("botPolicies.json") + if err != nil { + return nil, fmt.Errorf("[unexpected] can't parse builtin policy file %s: %w", fname, err) + } + } + + defer fin.Close() + + policy, err := policy.ParseConfig(fin, fname, defaultDifficulty) + + return policy, err +} + +func New(opts Options) (*Server, error) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate ed25519 key: %w", err) + } + + if err != nil { + return nil, err // parseConfig sets a fancy error for us + } + + result := &Server{ + next: opts.Next, + priv: priv, + pub: pub, + policy: opts.Policy, + DNSBLCache: decaymap.New[string, dnsbl.DroneBLResponse](), + } + + mux := http.NewServeMux() + xess.Mount(mux) + + mux.Handle(anubis.StaticPath, internal.UnchangingCache(http.StripPrefix(anubis.StaticPath, http.FileServerFS(web.Static)))) + + if opts.ServeRobotsTXT { + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + http.ServeFileFS(w, r, web.Static, "static/robots.txt") + }) + + mux.HandleFunc("/.well-known/robots.txt", func(w http.ResponseWriter, r *http.Request) { + http.ServeFileFS(w, r, web.Static, "static/robots.txt") + }) + } + + // mux.HandleFunc("GET /.within.website/x/cmd/anubis/static/js/main.mjs", serveMainJSWithBestEncoding) + + mux.HandleFunc("POST /.within.website/x/cmd/anubis/api/make-challenge", result.MakeChallenge) + mux.HandleFunc("GET /.within.website/x/cmd/anubis/api/pass-challenge", result.PassChallenge) + mux.HandleFunc("GET /.within.website/x/cmd/anubis/api/test-error", result.TestError) + + mux.HandleFunc("/", result.MaybeReverseProxy) + + result.mux = mux + + return result, nil +} + +type Server struct { + mux *http.ServeMux + next http.Handler + priv ed25519.PrivateKey + pub ed25519.PublicKey + policy *policy.ParsedConfig + DNSBLCache *decaymap.Impl[string, dnsbl.DroneBLResponse] + ChallengeDifficulty int +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} + +func (s *Server) challengeFor(r *http.Request, difficulty int) string { + fp := sha256.Sum256(s.priv.Seed()) + + data := fmt.Sprintf( + "Accept-Language=%s,X-Real-IP=%s,User-Agent=%s,WeekTime=%s,Fingerprint=%x,Difficulty=%d", + r.Header.Get("Accept-Language"), + r.Header.Get("X-Real-Ip"), + r.UserAgent(), + time.Now().UTC().Round(24*7*time.Hour).Format(time.RFC3339), + fp, + difficulty, + ) + return internal.SHA256sum(data) +} + +func (s *Server) MaybeReverseProxy(w http.ResponseWriter, r *http.Request) { + lg := slog.With( + "user_agent", r.UserAgent(), + "accept_language", r.Header.Get("Accept-Language"), + "priority", r.Header.Get("Priority"), + "x-forwarded-for", + r.Header.Get("X-Forwarded-For"), + "x-real-ip", r.Header.Get("X-Real-Ip"), + ) + + cr, rule, err := s.check(r) + if err != nil { + lg.Error("check failed", "err", err) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("Internal Server Error: administrator has misconfigured Anubis. Please contact the administrator and ask them to look for the logs around \"maybeReverseProxy\"")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + + r.Header.Add("X-Anubis-Rule", cr.Name) + r.Header.Add("X-Anubis-Action", string(cr.Rule)) + lg = lg.With("check_result", cr) + policy.PolicyApplications.WithLabelValues(cr.Name, string(cr.Rule)).Add(1) + + ip := r.Header.Get("X-Real-Ip") + + if s.policy.DNSBL && ip != "" { + resp, ok := s.DNSBLCache.Get(ip) + if !ok { + lg.Debug("looking up ip in dnsbl") + resp, err := dnsbl.Lookup(ip) + if err != nil { + lg.Error("can't look up ip in dnsbl", "err", err) + } + s.DNSBLCache.Set(ip, resp, 24*time.Hour) + droneBLHits.WithLabelValues(resp.String()).Inc() + } + + if resp != dnsbl.AllGood { + lg.Info("DNSBL hit", "status", resp.String()) + templ.Handler(web.Base("Oh noes!", web.ErrorPage(fmt.Sprintf("DroneBL reported an entry: %s, see https://dronebl.org/lookup?ip=%s", resp.String(), ip))), templ.WithStatus(http.StatusOK)).ServeHTTP(w, r) + return + } + } + + switch cr.Rule { + case config.RuleAllow: + lg.Debug("allowing traffic to origin (explicit)") + s.next.ServeHTTP(w, r) + return + case config.RuleDeny: + ClearCookie(w) + lg.Info("explicit deny") + if rule == nil { + lg.Error("rule is nil, cannot calculate checksum") + templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + hash, err := rule.Hash() + if err != nil { + lg.Error("can't calculate checksum of rule", "err", err) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + lg.Debug("rule hash", "hash", hash) + templ.Handler(web.Base("Oh noes!", web.ErrorPage(fmt.Sprintf("Access Denied: error code %s", hash))), templ.WithStatus(http.StatusOK)).ServeHTTP(w, r) + return + case config.RuleChallenge: + lg.Debug("challenge requested") + default: + ClearCookie(w) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("Other internal server error (contact the admin)")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + + ckie, err := r.Cookie(anubis.CookieName) + if err != nil { + lg.Debug("cookie not found", "path", r.URL.Path) + ClearCookie(w) + s.RenderIndex(w, r) + return + } + + if err := ckie.Valid(); err != nil { + lg.Debug("cookie is invalid", "err", err) + ClearCookie(w) + s.RenderIndex(w, r) + return + } + + if time.Now().After(ckie.Expires) && !ckie.Expires.IsZero() { + lg.Debug("cookie expired", "path", r.URL.Path) + ClearCookie(w) + s.RenderIndex(w, r) + return + } + + token, err := jwt.ParseWithClaims(ckie.Value, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) { + return s.pub, nil + }, jwt.WithExpirationRequired(), jwt.WithStrictDecoding()) + + if err != nil || !token.Valid { + lg.Debug("invalid token", "path", r.URL.Path, "err", err) + ClearCookie(w) + s.RenderIndex(w, r) + return + } + + if randomJitter() { + r.Header.Add("X-Anubis-Status", "PASS-BRIEF") + lg.Debug("cookie is not enrolled into secondary screening") + s.next.ServeHTTP(w, r) + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + lg.Debug("invalid token claims type", "path", r.URL.Path) + ClearCookie(w) + s.RenderIndex(w, r) + return + } + challenge := s.challengeFor(r, rule.Challenge.Difficulty) + + if claims["challenge"] != challenge { + lg.Debug("invalid challenge", "path", r.URL.Path) + ClearCookie(w) + s.RenderIndex(w, r) + return + } + + var nonce int + + if v, ok := claims["nonce"].(float64); ok { + nonce = int(v) + } + + calcString := fmt.Sprintf("%s%d", challenge, nonce) + calculated := internal.SHA256sum(calcString) + + if subtle.ConstantTimeCompare([]byte(claims["response"].(string)), []byte(calculated)) != 1 { + lg.Debug("invalid response", "path", r.URL.Path) + failedValidations.Inc() + ClearCookie(w) + s.RenderIndex(w, r) + return + } + + slog.Debug("all checks passed") + r.Header.Add("X-Anubis-Status", "PASS-FULL") + s.next.ServeHTTP(w, r) +} + +func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request) { + templ.Handler( + web.Base("Making sure you're not a bot!", web.Index()), + ).ServeHTTP(w, r) +} + +func (s *Server) MakeChallenge(w http.ResponseWriter, r *http.Request) { + lg := slog.With("user_agent", r.UserAgent(), "accept_language", r.Header.Get("Accept-Language"), "priority", r.Header.Get("Priority"), "x-forwarded-for", r.Header.Get("X-Forwarded-For"), "x-real-ip", r.Header.Get("X-Real-Ip")) + + cr, rule, err := s.check(r) + if err != nil { + lg.Error("check failed", "err", err) + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(struct { + Error string `json:"error"` + }{ + Error: "Internal Server Error: administrator has misconfigured Anubis. Please contact the administrator and ask them to look for the logs around \"makeChallenge\"", + }) + return + } + lg = lg.With("check_result", cr) + challenge := s.challengeFor(r, rule.Challenge.Difficulty) + + json.NewEncoder(w).Encode(struct { + Challenge string `json:"challenge"` + Rules *config.ChallengeRules `json:"rules"` + }{ + Challenge: challenge, + Rules: rule.Challenge, + }) + lg.Debug("made challenge", "challenge", challenge, "rules", rule.Challenge, "cr", cr) + challengesIssued.Inc() +} + +func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { + lg := slog.With( + "user_agent", r.UserAgent(), + "accept_language", r.Header.Get("Accept-Language"), + "priority", r.Header.Get("Priority"), + "x-forwarded-for", r.Header.Get("X-Forwarded-For"), + "x-real-ip", r.Header.Get("X-Real-Ip"), + ) + + cr, rule, err := s.check(r) + if err != nil { + lg.Error("check failed", "err", err) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("Internal Server Error: administrator has misconfigured Anubis. Please contact the administrator and ask them to look for the logs around \"passChallenge\".")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + lg = lg.With("check_result", cr) + + nonceStr := r.FormValue("nonce") + if nonceStr == "" { + ClearCookie(w) + lg.Debug("no nonce") + templ.Handler(web.Base("Oh noes!", web.ErrorPage("missing nonce")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + + elapsedTimeStr := r.FormValue("elapsedTime") + if elapsedTimeStr == "" { + ClearCookie(w) + lg.Debug("no elapsedTime") + templ.Handler(web.Base("Oh noes!", web.ErrorPage("missing elapsedTime")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + + elapsedTime, err := strconv.ParseFloat(elapsedTimeStr, 64) + if err != nil { + ClearCookie(w) + lg.Debug("elapsedTime doesn't parse", "err", err) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid elapsedTime")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + + lg.Info("challenge took", "elapsedTime", elapsedTime) + timeTaken.Observe(elapsedTime) + + response := r.FormValue("response") + redir := r.FormValue("redir") + + challenge := s.challengeFor(r, rule.Challenge.Difficulty) + + nonce, err := strconv.Atoi(nonceStr) + if err != nil { + ClearCookie(w) + lg.Debug("nonce doesn't parse", "err", err) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid nonce")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + + calcString := fmt.Sprintf("%s%d", challenge, nonce) + calculated := internal.SHA256sum(calcString) + + if subtle.ConstantTimeCompare([]byte(response), []byte(calculated)) != 1 { + ClearCookie(w) + lg.Debug("hash does not match", "got", response, "want", calculated) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid response")), templ.WithStatus(http.StatusForbidden)).ServeHTTP(w, r) + failedValidations.Inc() + return + } + + // compare the leading zeroes + if !strings.HasPrefix(response, strings.Repeat("0", s.ChallengeDifficulty)) { + ClearCookie(w) + lg.Debug("difficulty check failed", "response", response, "difficulty", s.ChallengeDifficulty) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("invalid response")), templ.WithStatus(http.StatusForbidden)).ServeHTTP(w, r) + failedValidations.Inc() + return + } + + // generate JWT cookie + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jwt.MapClaims{ + "challenge": challenge, + "nonce": nonce, + "response": response, + "iat": time.Now().Unix(), + "nbf": time.Now().Add(-1 * time.Minute).Unix(), + "exp": time.Now().Add(24 * 7 * time.Hour).Unix(), + }) + tokenString, err := token.SignedString(s.priv) + if err != nil { + lg.Error("failed to sign JWT", "err", err) + ClearCookie(w) + templ.Handler(web.Base("Oh noes!", web.ErrorPage("failed to sign JWT")), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: anubis.CookieName, + Value: tokenString, + Expires: time.Now().Add(24 * 7 * time.Hour), + SameSite: http.SameSiteLaxMode, + Path: "/", + }) + + challengesValidated.Inc() + lg.Debug("challenge passed, redirecting to app") + http.Redirect(w, r, redir, http.StatusFound) +} + +func (s *Server) TestError(w http.ResponseWriter, r *http.Request) { + err := r.FormValue("err") + templ.Handler(web.Base("Oh noes!", web.ErrorPage(err)), templ.WithStatus(http.StatusInternalServerError)).ServeHTTP(w, r) +} + +// Check evaluates the list of rules, and returns the result +func (s *Server) check(r *http.Request) (CheckResult, *policy.Bot, error) { + host := r.Header.Get("X-Real-Ip") + if host == "" { + return decaymap.Zilch[CheckResult](), nil, fmt.Errorf("[misconfiguration] X-Real-Ip header is not set") + } + + addr := net.ParseIP(host) + if addr == nil { + return decaymap.Zilch[CheckResult](), nil, fmt.Errorf("[misconfiguration] %q is not an IP address", host) + } + + for _, b := range s.policy.Bots { + if b.UserAgent != nil { + if b.UserAgent.MatchString(r.UserAgent()) && s.checkRemoteAddress(b, addr) { + return cr("bot/"+b.Name, b.Action), &b, nil + } + } + + if b.Path != nil { + if b.Path.MatchString(r.URL.Path) && s.checkRemoteAddress(b, addr) { + return cr("bot/"+b.Name, b.Action), &b, nil + } + } + + if b.Ranger != nil { + if s.checkRemoteAddress(b, addr) { + return cr("bot/"+b.Name, b.Action), &b, nil + } + } + } + + return cr("default/allow", config.RuleAllow), &policy.Bot{ + Challenge: &config.ChallengeRules{ + Difficulty: anubis.DefaultDifficulty, + ReportAs: anubis.DefaultDifficulty, + Algorithm: config.AlgorithmFast, + }, + }, nil +} + +func (s *Server) checkRemoteAddress(b policy.Bot, addr net.IP) bool { + if b.Ranger == nil { + return true + } + + ok, err := b.Ranger.Contains(addr) + if err != nil { + log.Panicf("[unexpected] something very funky is going on, %q does not have a calculable network number: %v", addr.String(), err) + } + + return ok +} diff --git a/lib/checkresult.go b/lib/checkresult.go new file mode 100644 index 0000000..3803df2 --- /dev/null +++ b/lib/checkresult.go @@ -0,0 +1,25 @@ +package lib + +import ( + "log/slog" + + "github.com/TecharoHQ/anubis/lib/policy/config" +) + +type CheckResult struct { + Name string + Rule config.Rule +} + +func (cr CheckResult) LogValue() slog.Value { + return slog.GroupValue( + slog.String("name", cr.Name), + slog.String("rule", string(cr.Rule))) +} + +func cr(name string, rule config.Rule) CheckResult { + return CheckResult{ + Name: name, + Rule: rule, + } +} diff --git a/lib/http.go b/lib/http.go new file mode 100644 index 0000000..1284523 --- /dev/null +++ b/lib/http.go @@ -0,0 +1,34 @@ +package lib + +import ( + "net/http" + "time" + + "github.com/TecharoHQ/anubis" +) + +func ClearCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: anubis.CookieName, + Value: "", + Expires: time.Now().Add(-1 * time.Hour), + MaxAge: -1, + SameSite: http.SameSiteLaxMode, + }) +} + +// https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124 +type UnixRoundTripper struct { + Transport *http.Transport +} + +// set bare minimum stuff +func (t UnixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + if req.Host == "" { + req.Host = "localhost" + } + req.URL.Host = req.Host // proxy error: no Host in request URL + req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion + return t.Transport.RoundTrip(req) +} diff --git a/lib/policy/bot.go b/lib/policy/bot.go new file mode 100644 index 0000000..d9ca135 --- /dev/null +++ b/lib/policy/bot.go @@ -0,0 +1,32 @@ +package policy + +import ( + "fmt" + "regexp" + + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/policy/config" + "github.com/yl2chen/cidranger" +) + +type Bot struct { + Name string + UserAgent *regexp.Regexp + Path *regexp.Regexp + Action config.Rule `json:"action"` + Challenge *config.ChallengeRules + Ranger cidranger.Ranger +} + +func (b Bot) Hash() (string, error) { + var pathRex string + if b.Path != nil { + pathRex = b.Path.String() + } + var userAgentRex string + if b.UserAgent != nil { + userAgentRex = b.UserAgent.String() + } + + return internal.SHA256sum(fmt.Sprintf("%s::%s::%s", b.Name, pathRex, userAgentRex)), nil +} diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go new file mode 100644 index 0000000..67eddbf --- /dev/null +++ b/lib/policy/config/config.go @@ -0,0 +1,162 @@ +package config + +import ( + "errors" + "fmt" + "net" + "regexp" +) + +var ( + ErrNoBotRulesDefined = errors.New("config: must define at least one (1) bot rule") + ErrBotMustHaveName = errors.New("config.Bot: must set name") + ErrBotMustHaveUserAgentOrPath = errors.New("config.Bot: must set either user_agent_regex, path_regex, or remote_addresses") + ErrBotMustHaveUserAgentOrPathNotBoth = errors.New("config.Bot: must set either user_agent_regex, path_regex, and not both") + ErrUnknownAction = errors.New("config.Bot: unknown action") + ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex") + ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex") + ErrInvalidCIDR = errors.New("config.Bot: invalid CIDR") +) + +type Rule string + +const ( + RuleUnknown Rule = "" + RuleAllow Rule = "ALLOW" + RuleDeny Rule = "DENY" + RuleChallenge Rule = "CHALLENGE" +) + +type Algorithm string + +const ( + AlgorithmUnknown Algorithm = "" + AlgorithmFast Algorithm = "fast" + AlgorithmSlow Algorithm = "slow" +) + +type BotConfig struct { + Name string `json:"name"` + UserAgentRegex *string `json:"user_agent_regex"` + PathRegex *string `json:"path_regex"` + Action Rule `json:"action"` + RemoteAddr []string `json:"remote_addresses"` + Challenge *ChallengeRules `json:"challenge,omitempty"` +} + +func (b BotConfig) Valid() error { + var errs []error + + if b.Name == "" { + errs = append(errs, ErrBotMustHaveName) + } + + if b.UserAgentRegex == nil && b.PathRegex == nil && (b.RemoteAddr == nil || len(b.RemoteAddr) == 0) { + errs = append(errs, ErrBotMustHaveUserAgentOrPath) + } + + if b.UserAgentRegex != nil && b.PathRegex != nil { + errs = append(errs, ErrBotMustHaveUserAgentOrPathNotBoth) + } + + if b.UserAgentRegex != nil { + if _, err := regexp.Compile(*b.UserAgentRegex); err != nil { + errs = append(errs, ErrInvalidUserAgentRegex, err) + } + } + + if b.PathRegex != nil { + if _, err := regexp.Compile(*b.PathRegex); err != nil { + errs = append(errs, ErrInvalidPathRegex, err) + } + } + + if b.RemoteAddr != nil && len(b.RemoteAddr) > 0 { + for _, cidr := range b.RemoteAddr { + if _, _, err := net.ParseCIDR(cidr); err != nil { + errs = append(errs, ErrInvalidCIDR, err) + } + } + } + + switch b.Action { + case RuleAllow, RuleChallenge, RuleDeny: + // okay + default: + errs = append(errs, fmt.Errorf("%w: %q", ErrUnknownAction, b.Action)) + } + + if b.Action == RuleChallenge && b.Challenge != nil { + if err := b.Challenge.Valid(); err != nil { + errs = append(errs, err) + } + } + + if len(errs) != 0 { + return fmt.Errorf("config: bot entry for %q is not valid:\n%w", b.Name, errors.Join(errs...)) + } + + return nil +} + +type ChallengeRules struct { + Difficulty int `json:"difficulty"` + ReportAs int `json:"report_as"` + Algorithm Algorithm `json:"algorithm"` +} + +var ( + ErrChallengeRuleHasWrongAlgorithm = errors.New("config.Bot.ChallengeRules: algorithm is invalid") + ErrChallengeDifficultyTooLow = errors.New("config.Bot.ChallengeRules: difficulty is too low (must be >= 1)") + ErrChallengeDifficultyTooHigh = errors.New("config.Bot.ChallengeRules: difficulty is too high (must be <= 64)") +) + +func (cr ChallengeRules) Valid() error { + var errs []error + + if cr.Difficulty < 1 { + errs = append(errs, fmt.Errorf("%w, got: %d", ErrChallengeDifficultyTooLow, cr.Difficulty)) + } + + if cr.Difficulty > 64 { + errs = append(errs, fmt.Errorf("%w, got: %d", ErrChallengeDifficultyTooHigh, cr.Difficulty)) + } + + switch cr.Algorithm { + case AlgorithmFast, AlgorithmSlow, AlgorithmUnknown: + // do nothing, it's all good + default: + errs = append(errs, fmt.Errorf("%w: %q", ErrChallengeRuleHasWrongAlgorithm, cr.Algorithm)) + } + + if len(errs) != 0 { + return fmt.Errorf("config: challenge rules entry is not valid:\n%w", errors.Join(errs...)) + } + + return nil +} + +type Config struct { + Bots []BotConfig `json:"bots"` + DNSBL bool `json:"dnsbl"` +} + +func (c Config) Valid() error { + var errs []error + + if len(c.Bots) == 0 { + errs = append(errs, ErrNoBotRulesDefined) + } + + for _, b := range c.Bots { + if err := b.Valid(); err != nil { + errs = append(errs, err) + } + } + + if len(errs) != 0 { + return fmt.Errorf("config is not valid:\n%w", errors.Join(errs...)) + } + + return nil +} diff --git a/lib/policy/config/config_test.go b/lib/policy/config/config_test.go new file mode 100644 index 0000000..a169087 --- /dev/null +++ b/lib/policy/config/config_test.go @@ -0,0 +1,248 @@ +package config + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "testing" +) + +func p[V any](v V) *V { return &v } + +func TestBotValid(t *testing.T) { + var tests = []struct { + name string + bot BotConfig + err error + }{ + { + name: "simple user agent", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleChallenge, + UserAgentRegex: p("Mozilla"), + }, + err: nil, + }, + { + name: "simple path", + bot: BotConfig{ + Name: "well-known-path", + Action: RuleAllow, + PathRegex: p("^/.well-known/.*$"), + }, + err: nil, + }, + { + name: "no rule name", + bot: BotConfig{ + Action: RuleChallenge, + UserAgentRegex: p("Mozilla"), + }, + err: ErrBotMustHaveName, + }, + { + name: "no rule matcher", + bot: BotConfig{ + Name: "broken-rule", + Action: RuleAllow, + }, + err: ErrBotMustHaveUserAgentOrPath, + }, + { + name: "both user-agent and path", + bot: BotConfig{ + Name: "path-and-user-agent", + Action: RuleDeny, + UserAgentRegex: p("Mozilla"), + PathRegex: p("^/.secret-place/.*$"), + }, + err: ErrBotMustHaveUserAgentOrPathNotBoth, + }, + { + name: "unknown action", + bot: BotConfig{ + Name: "Unknown action", + Action: RuleUnknown, + UserAgentRegex: p("Mozilla"), + }, + err: ErrUnknownAction, + }, + { + name: "invalid user agent regex", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleChallenge, + UserAgentRegex: p("a(b"), + }, + err: ErrInvalidUserAgentRegex, + }, + { + name: "invalid path regex", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleChallenge, + PathRegex: p("a(b"), + }, + err: ErrInvalidPathRegex, + }, + { + name: "challenge difficulty too low", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleChallenge, + PathRegex: p("Mozilla"), + Challenge: &ChallengeRules{ + Difficulty: 0, + ReportAs: 4, + Algorithm: "fast", + }, + }, + err: ErrChallengeDifficultyTooLow, + }, + { + name: "challenge difficulty too high", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleChallenge, + PathRegex: p("Mozilla"), + Challenge: &ChallengeRules{ + Difficulty: 420, + ReportAs: 4, + Algorithm: "fast", + }, + }, + err: ErrChallengeDifficultyTooHigh, + }, + { + name: "challenge wrong algorithm", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleChallenge, + PathRegex: p("Mozilla"), + Challenge: &ChallengeRules{ + Difficulty: 420, + ReportAs: 4, + Algorithm: "high quality rips", + }, + }, + err: ErrChallengeRuleHasWrongAlgorithm, + }, + { + name: "invalid cidr range", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleAllow, + RemoteAddr: []string{"0.0.0.0/33"}, + }, + err: ErrInvalidCIDR, + }, + { + name: "only filter by IP range", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleAllow, + RemoteAddr: []string{"0.0.0.0/0"}, + }, + err: nil, + }, + { + name: "filter by user agent and IP range", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleAllow, + UserAgentRegex: p("Mozilla"), + RemoteAddr: []string{"0.0.0.0/0"}, + }, + err: nil, + }, + { + name: "filter by path and IP range", + bot: BotConfig{ + Name: "mozilla-ua", + Action: RuleAllow, + PathRegex: p("^.*$"), + RemoteAddr: []string{"0.0.0.0/0"}, + }, + err: nil, + }, + } + + for _, cs := range tests { + cs := cs + t.Run(cs.name, func(t *testing.T) { + err := cs.bot.Valid() + if err == nil && cs.err == nil { + return + } + + if err == nil && cs.err != nil { + t.Errorf("didn't get an error, but wanted: %v", cs.err) + } + + if !errors.Is(err, cs.err) { + t.Logf("got wrong error from Valid()") + t.Logf("wanted: %v", cs.err) + t.Logf("got: %v", err) + t.Errorf("got invalid error from check") + } + }) + } +} + +func TestConfigValidKnownGood(t *testing.T) { + finfos, err := os.ReadDir("testdata/good") + if err != nil { + t.Fatal(err) + } + + for _, st := range finfos { + st := st + t.Run(st.Name(), func(t *testing.T) { + fin, err := os.Open(filepath.Join("testdata", "good", st.Name())) + if err != nil { + t.Fatal(err) + } + defer fin.Close() + + var c Config + if err := json.NewDecoder(fin).Decode(&c); err != nil { + t.Fatalf("can't decode file: %v", err) + } + + if err := c.Valid(); err != nil { + t.Fatal(err) + } + }) |
