diff options
| author | Xe Iaso <me@xeiaso.net> | 2025-04-22 07:49:41 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-22 07:49:41 -0400 |
| commit | 84b28760b3b54c7d26ad40a1e7343d6de242ad9b (patch) | |
| tree | 3104e239da1ac3cedc43ee07d58bb6ad5cc9f12c /lib/policy | |
| parent | 9b7bf8ee06fd84f3aa3298f483d8bca69c3372d4 (diff) | |
| download | anubis-84b28760b3b54c7d26ad40a1e7343d6de242ad9b.tar.xz anubis-84b28760b3b54c7d26ad40a1e7343d6de242ad9b.zip | |
feat(lib): use Checker type instead of ad-hoc logic (#318)
This makes each check into its own type that has encapsulated check
logic, meaning that it's easier to add new checker implementations in
the future.
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'lib/policy')
| -rw-r--r-- | lib/policy/bot.go | 35 | ||||
| -rw-r--r-- | lib/policy/checker.go | 201 | ||||
| -rw-r--r-- | lib/policy/checker_test.go | 200 | ||||
| -rw-r--r-- | lib/policy/checkresult.go | 18 | ||||
| -rw-r--r-- | lib/policy/policy.go | 57 |
5 files changed, 445 insertions, 66 deletions
diff --git a/lib/policy/bot.go b/lib/policy/bot.go index e656d9a..3a43655 100644 --- a/lib/policy/bot.go +++ b/lib/policy/bot.go @@ -2,45 +2,18 @@ package policy import ( "fmt" - "regexp" - "strings" "github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/lib/policy/config" - "github.com/yl2chen/cidranger" ) type Bot struct { Name string - UserAgent *regexp.Regexp - Path *regexp.Regexp - Headers map[string]*regexp.Regexp - Action config.Rule `json:"action"` + Action config.Rule Challenge *config.ChallengeRules - Ranger cidranger.Ranger + Rules Checker } -func (b Bot) Hash() (string, error) { - var pathRex string - if b.Path != nil { - pathRex = b.Path.String() - } - var userAgentRex string - if b.UserAgent != nil { - userAgentRex = b.UserAgent.String() - } - var headersRex string - if len(b.Headers) > 0 { - var sb strings.Builder - sb.Grow(len(b.Headers) * 64) - - for name, expr := range b.Headers { - sb.WriteString(name) - sb.WriteString(expr.String()) - } - - headersRex = sb.String() - } - - return internal.SHA256sum(fmt.Sprintf("%s::%s::%s::%s", b.Name, pathRex, userAgentRex, headersRex)), nil +func (b Bot) Hash() string { + return internal.SHA256sum(fmt.Sprintf("%s::%s", b.Name, b.Rules.Hash())) } diff --git a/lib/policy/checker.go b/lib/policy/checker.go new file mode 100644 index 0000000..ad98ced --- /dev/null +++ b/lib/policy/checker.go @@ -0,0 +1,201 @@ +package policy + +import ( + "errors" + "fmt" + "net" + "net/http" + "regexp" + "strings" + + "github.com/TecharoHQ/anubis/internal" + "github.com/yl2chen/cidranger" +) + +var ( + ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration") +) + +type Checker interface { + Check(*http.Request) (bool, error) + Hash() string +} + +type CheckerList []Checker + +func (cl CheckerList) Check(r *http.Request) (bool, error) { + for _, c := range cl { + ok, err := c.Check(r) + if err != nil { + return ok, err + } + if ok { + return ok, nil + } + } + + return false, nil +} + +func (cl CheckerList) Hash() string { + var sb strings.Builder + + for _, c := range cl { + fmt.Fprintln(&sb, c.Hash()) + } + + return internal.SHA256sum(sb.String()) +} + +type RemoteAddrChecker struct { + ranger cidranger.Ranger + hash string +} + +func NewRemoteAddrChecker(cidrs []string) (Checker, error) { + ranger := cidranger.NewPCTrieRanger() + var sb strings.Builder + + for _, cidr := range cidrs { + _, rng, err := net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err) + } + + ranger.Insert(cidranger.NewBasicRangerEntry(*rng)) + fmt.Fprintln(&sb, cidr) + } + + return &RemoteAddrChecker{ + ranger: ranger, + hash: internal.SHA256sum(sb.String()), + }, nil +} + +func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) { + host := r.Header.Get("X-Real-Ip") + if host == "" { + return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration) + } + + addr := net.ParseIP(host) + if addr == nil { + return false, fmt.Errorf("%w: %s is not an IP address", ErrMisconfiguration, host) + } + + ok, err := rac.ranger.Contains(addr) + if err != nil { + return false, err + } + + if ok { + return true, nil + } + + return false, nil +} + +func (rac *RemoteAddrChecker) Hash() string { + return rac.hash +} + +type HeaderMatchesChecker struct { + header string + regexp *regexp.Regexp + hash string +} + +func NewUserAgentChecker(rexStr string) (Checker, error) { + return NewHeaderMatchesChecker("User-Agent", rexStr) +} + +func NewHeaderMatchesChecker(header, rexStr string) (Checker, error) { + rex, err := regexp.Compile(rexStr) + if err != nil { + return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) + } + return &HeaderMatchesChecker{header, rex, internal.SHA256sum(header + ": " + rexStr)}, nil +} + +func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) { + if hmc.regexp.MatchString(r.Header.Get(hmc.header)) { + return true, nil + } + + return false, nil +} + +func (hmc *HeaderMatchesChecker) Hash() string { + return hmc.hash +} + +type PathChecker struct { + regexp *regexp.Regexp + hash string +} + +func NewPathChecker(rexStr string) (Checker, error) { + rex, err := regexp.Compile(rexStr) + if err != nil { + return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) + } + return &PathChecker{rex, internal.SHA256sum(rexStr)}, nil +} + +func (pc *PathChecker) Check(r *http.Request) (bool, error) { + if pc.regexp.MatchString(r.URL.Path) { + return true, nil + } + + return false, nil +} + +func (pc *PathChecker) Hash() string { + return pc.hash +} + +func NewHeaderExistsChecker(key string) Checker { + return headerExistsChecker{key} +} + +type headerExistsChecker struct { + header string +} + +func (hec headerExistsChecker) Check(r *http.Request) (bool, error) { + if r.Header.Get(hec.header) != "" { + return true, nil + } + + return false, nil +} + +func (hec headerExistsChecker) Hash() string { + return internal.SHA256sum(hec.header) +} + +func NewHeadersChecker(headermap map[string]string) (Checker, error) { + var result CheckerList + var errs []error + + for key, rexStr := range headermap { + if rexStr == ".*" { + result = append(result, headerExistsChecker{key}) + continue + } + + rex, err := regexp.Compile(rexStr) + if err != nil { + errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err)) + continue + } + + result = append(result, &HeaderMatchesChecker{key, rex, internal.SHA256sum(key + ": " + rexStr)}) + } + + if len(errs) != 0 { + return nil, errors.Join(errs...) + } + + return result, nil +} diff --git a/lib/policy/checker_test.go b/lib/policy/checker_test.go new file mode 100644 index 0000000..6739509 --- /dev/null +++ b/lib/policy/checker_test.go @@ -0,0 +1,200 @@ +package policy + +import ( + "errors" + "net/http" + "testing" +) + +func TestRemoteAddrChecker(t *testing.T) { + for _, tt := range []struct { + name string + cidrs []string + ip string + ok bool + err error + }{ + { + name: "match_ipv4", + cidrs: []string{"0.0.0.0/0"}, + ip: "1.1.1.1", + ok: true, + err: nil, + }, + { + name: "match_ipv6", + cidrs: []string{"::/0"}, + ip: "cafe:babe::", + ok: true, + err: nil, + }, + { + name: "not_match_ipv4", + cidrs: []string{"1.1.1.1/32"}, + ip: "1.1.1.2", + ok: false, + err: nil, + }, + { + name: "not_match_ipv6", + cidrs: []string{"cafe:babe::/128"}, + ip: "cafe:babe:4::/128", + ok: false, + err: nil, + }, + { + name: "no_ip_set", + cidrs: []string{"::/0"}, + ok: false, + err: ErrMisconfiguration, + }, + { + name: "invalid_ip", + cidrs: []string{"::/0"}, + ip: "According to all natural laws of aviation", + ok: false, + err: ErrMisconfiguration, + }, + } { + t.Run(tt.name, func(t *testing.T) { + rac, err := NewRemoteAddrChecker(tt.cidrs) + if err != nil && !errors.Is(err, tt.err) { + t.Fatalf("creating RemoteAddrChecker failed: %v", err) + } + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + if tt.ip != "" { + r.Header.Add("X-Real-Ip", tt.ip) + } + + ok, err := rac.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + }) + } +} + +func TestHeaderMatchesChecker(t *testing.T) { + for _, tt := range []struct { + name string + header string + rexStr string + reqHeaderKey string + reqHeaderValue string + ok bool + err error + }{ + { + name: "match", + header: "Cf-Worker", + rexStr: ".*", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: true, + err: nil, + }, + { + name: "not_match", + header: "Cf-Worker", + rexStr: "false", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "not_present", + header: "Cf-Worker", + rexStr: "foobar", + reqHeaderKey: "Something-Else", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "invalid_regex", + rexStr: "a(b", + err: ErrMisconfiguration, + }, + } { + t.Run(tt.name, func(t *testing.T) { + hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr) + if err != nil && !errors.Is(err, tt.err) { + t.Fatalf("creating HeaderMatchesChecker failed") + } + + if tt.err != nil && hmc == nil { + return + } + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue) + + ok, err := hmc.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + }) + } +} + +func TestHeaderExistsChecker(t *testing.T) { + for _, tt := range []struct { + name string + header string + reqHeader string + ok bool + }{ + { + name: "match", + header: "Authorization", + reqHeader: "Authorization", + ok: true, + }, + { + name: "not_match", + header: "Authorization", + reqHeader: "Authentication", + }, + } { + t.Run(tt.name, func(t *testing.T) { + hec := headerExistsChecker{tt.header} + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeader, "hunter2") + + ok, err := hec.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil { + t.Errorf("err: %v", err) + } + }) + } +} diff --git a/lib/policy/checkresult.go b/lib/policy/checkresult.go new file mode 100644 index 0000000..c84f326 --- /dev/null +++ b/lib/policy/checkresult.go @@ -0,0 +1,18 @@ +package policy + +import ( + "log/slog" + + "github.com/TecharoHQ/anubis/lib/policy/config" +) + +type CheckResult struct { + Name string + Rule config.Rule +} + +func (cr CheckResult) LogValue() slog.Value { + return slog.GroupValue( + slog.String("name", cr.Name), + slog.String("rule", string(cr.Rule))) +} diff --git a/lib/policy/policy.go b/lib/policy/policy.go index 2d610c8..5923f16 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -4,12 +4,9 @@ import ( "errors" "fmt" "io" - "net" - "regexp" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/yl2chen/cidranger" "k8s.io/apimachinery/pkg/util/yaml" "github.com/TecharoHQ/anubis/lib/policy/config" @@ -58,57 +55,45 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon } parsedBot := Bot{ - Name: b.Name, - Action: b.Action, - Headers: map[string]*regexp.Regexp{}, + Name: b.Name, + Action: b.Action, } - if len(b.RemoteAddr) > 0 { - parsedBot.Ranger = cidranger.NewPCTrieRanger() - - for _, cidr := range b.RemoteAddr { - _, rng, err := net.ParseCIDR(cidr) - if err != nil { - return nil, fmt.Errorf("[unexpected] range %s not parsing: %w", cidr, err) - } + cl := CheckerList{} - parsedBot.Ranger.Insert(cidranger.NewBasicRangerEntry(*rng)) + if len(b.RemoteAddr) > 0 { + c, err := NewRemoteAddrChecker(b.RemoteAddr) + if err != nil { + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err)) + } else { + cl = append(cl, c) } } if b.UserAgentRegex != nil { - userAgent, err := regexp.Compile(*b.UserAgentRegex) + c, err := NewUserAgentChecker(*b.UserAgentRegex) if err != nil { - validationErrs = append(validationErrs, fmt.Errorf("while compiling user agent regexp: %w", err)) - continue + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err)) } else { - parsedBot.UserAgent = userAgent + cl = append(cl, c) } } if b.PathRegex != nil { - path, err := regexp.Compile(*b.PathRegex) + c, err := NewPathChecker(*b.PathRegex) if err != nil { - validationErrs = append(validationErrs, fmt.Errorf("while compiling path regexp: %w", err)) - continue + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err)) } else { - parsedBot.Path = path + cl = append(cl, c) } } if len(b.HeadersRegex) > 0 { - for name, expr := range b.HeadersRegex { - if name == "" { - continue - } - - header, err := regexp.Compile(expr) - if err != nil { - validationErrs = append(validationErrs, fmt.Errorf("while compiling header regexp: %w", err)) - continue - } else { - parsedBot.Headers[name] = header - } + c, err := NewHeadersChecker(b.HeadersRegex) + if err != nil { + validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s headers regex map: %w", b.Name, err)) + } else { + cl = append(cl, c) } } @@ -125,6 +110,8 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon } } + parsedBot.Rules = cl + result.Bots = append(result.Bots, parsedBot) } |
