aboutsummaryrefslogtreecommitdiff
path: root/lib/policy
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2025-04-22 07:49:41 -0400
committerGitHub <noreply@github.com>2025-04-22 07:49:41 -0400
commit84b28760b3b54c7d26ad40a1e7343d6de242ad9b (patch)
tree3104e239da1ac3cedc43ee07d58bb6ad5cc9f12c /lib/policy
parent9b7bf8ee06fd84f3aa3298f483d8bca69c3372d4 (diff)
downloadanubis-84b28760b3b54c7d26ad40a1e7343d6de242ad9b.tar.xz
anubis-84b28760b3b54c7d26ad40a1e7343d6de242ad9b.zip
feat(lib): use Checker type instead of ad-hoc logic (#318)
This makes each check into its own type that has encapsulated check logic, meaning that it's easier to add new checker implementations in the future. Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'lib/policy')
-rw-r--r--lib/policy/bot.go35
-rw-r--r--lib/policy/checker.go201
-rw-r--r--lib/policy/checker_test.go200
-rw-r--r--lib/policy/checkresult.go18
-rw-r--r--lib/policy/policy.go57
5 files changed, 445 insertions, 66 deletions
diff --git a/lib/policy/bot.go b/lib/policy/bot.go
index e656d9a..3a43655 100644
--- a/lib/policy/bot.go
+++ b/lib/policy/bot.go
@@ -2,45 +2,18 @@ package policy
import (
"fmt"
- "regexp"
- "strings"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/config"
- "github.com/yl2chen/cidranger"
)
type Bot struct {
Name string
- UserAgent *regexp.Regexp
- Path *regexp.Regexp
- Headers map[string]*regexp.Regexp
- Action config.Rule `json:"action"`
+ Action config.Rule
Challenge *config.ChallengeRules
- Ranger cidranger.Ranger
+ Rules Checker
}
-func (b Bot) Hash() (string, error) {
- var pathRex string
- if b.Path != nil {
- pathRex = b.Path.String()
- }
- var userAgentRex string
- if b.UserAgent != nil {
- userAgentRex = b.UserAgent.String()
- }
- var headersRex string
- if len(b.Headers) > 0 {
- var sb strings.Builder
- sb.Grow(len(b.Headers) * 64)
-
- for name, expr := range b.Headers {
- sb.WriteString(name)
- sb.WriteString(expr.String())
- }
-
- headersRex = sb.String()
- }
-
- return internal.SHA256sum(fmt.Sprintf("%s::%s::%s::%s", b.Name, pathRex, userAgentRex, headersRex)), nil
+func (b Bot) Hash() string {
+ return internal.SHA256sum(fmt.Sprintf("%s::%s", b.Name, b.Rules.Hash()))
}
diff --git a/lib/policy/checker.go b/lib/policy/checker.go
new file mode 100644
index 0000000..ad98ced
--- /dev/null
+++ b/lib/policy/checker.go
@@ -0,0 +1,201 @@
+package policy
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "regexp"
+ "strings"
+
+ "github.com/TecharoHQ/anubis/internal"
+ "github.com/yl2chen/cidranger"
+)
+
+var (
+ ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
+)
+
+type Checker interface {
+ Check(*http.Request) (bool, error)
+ Hash() string
+}
+
+type CheckerList []Checker
+
+func (cl CheckerList) Check(r *http.Request) (bool, error) {
+ for _, c := range cl {
+ ok, err := c.Check(r)
+ if err != nil {
+ return ok, err
+ }
+ if ok {
+ return ok, nil
+ }
+ }
+
+ return false, nil
+}
+
+func (cl CheckerList) Hash() string {
+ var sb strings.Builder
+
+ for _, c := range cl {
+ fmt.Fprintln(&sb, c.Hash())
+ }
+
+ return internal.SHA256sum(sb.String())
+}
+
+type RemoteAddrChecker struct {
+ ranger cidranger.Ranger
+ hash string
+}
+
+func NewRemoteAddrChecker(cidrs []string) (Checker, error) {
+ ranger := cidranger.NewPCTrieRanger()
+ var sb strings.Builder
+
+ for _, cidr := range cidrs {
+ _, rng, err := net.ParseCIDR(cidr)
+ if err != nil {
+ return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err)
+ }
+
+ ranger.Insert(cidranger.NewBasicRangerEntry(*rng))
+ fmt.Fprintln(&sb, cidr)
+ }
+
+ return &RemoteAddrChecker{
+ ranger: ranger,
+ hash: internal.SHA256sum(sb.String()),
+ }, nil
+}
+
+func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
+ host := r.Header.Get("X-Real-Ip")
+ if host == "" {
+ return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration)
+ }
+
+ addr := net.ParseIP(host)
+ if addr == nil {
+ return false, fmt.Errorf("%w: %s is not an IP address", ErrMisconfiguration, host)
+ }
+
+ ok, err := rac.ranger.Contains(addr)
+ if err != nil {
+ return false, err
+ }
+
+ if ok {
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (rac *RemoteAddrChecker) Hash() string {
+ return rac.hash
+}
+
+type HeaderMatchesChecker struct {
+ header string
+ regexp *regexp.Regexp
+ hash string
+}
+
+func NewUserAgentChecker(rexStr string) (Checker, error) {
+ return NewHeaderMatchesChecker("User-Agent", rexStr)
+}
+
+func NewHeaderMatchesChecker(header, rexStr string) (Checker, error) {
+ rex, err := regexp.Compile(rexStr)
+ if err != nil {
+ return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
+ }
+ return &HeaderMatchesChecker{header, rex, internal.SHA256sum(header + ": " + rexStr)}, nil
+}
+
+func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
+ if hmc.regexp.MatchString(r.Header.Get(hmc.header)) {
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (hmc *HeaderMatchesChecker) Hash() string {
+ return hmc.hash
+}
+
+type PathChecker struct {
+ regexp *regexp.Regexp
+ hash string
+}
+
+func NewPathChecker(rexStr string) (Checker, error) {
+ rex, err := regexp.Compile(rexStr)
+ if err != nil {
+ return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
+ }
+ return &PathChecker{rex, internal.SHA256sum(rexStr)}, nil
+}
+
+func (pc *PathChecker) Check(r *http.Request) (bool, error) {
+ if pc.regexp.MatchString(r.URL.Path) {
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (pc *PathChecker) Hash() string {
+ return pc.hash
+}
+
+func NewHeaderExistsChecker(key string) Checker {
+ return headerExistsChecker{key}
+}
+
+type headerExistsChecker struct {
+ header string
+}
+
+func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
+ if r.Header.Get(hec.header) != "" {
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (hec headerExistsChecker) Hash() string {
+ return internal.SHA256sum(hec.header)
+}
+
+func NewHeadersChecker(headermap map[string]string) (Checker, error) {
+ var result CheckerList
+ var errs []error
+
+ for key, rexStr := range headermap {
+ if rexStr == ".*" {
+ result = append(result, headerExistsChecker{key})
+ continue
+ }
+
+ rex, err := regexp.Compile(rexStr)
+ if err != nil {
+ errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err))
+ continue
+ }
+
+ result = append(result, &HeaderMatchesChecker{key, rex, internal.SHA256sum(key + ": " + rexStr)})
+ }
+
+ if len(errs) != 0 {
+ return nil, errors.Join(errs...)
+ }
+
+ return result, nil
+}
diff --git a/lib/policy/checker_test.go b/lib/policy/checker_test.go
new file mode 100644
index 0000000..6739509
--- /dev/null
+++ b/lib/policy/checker_test.go
@@ -0,0 +1,200 @@
+package policy
+
+import (
+ "errors"
+ "net/http"
+ "testing"
+)
+
+func TestRemoteAddrChecker(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ cidrs []string
+ ip string
+ ok bool
+ err error
+ }{
+ {
+ name: "match_ipv4",
+ cidrs: []string{"0.0.0.0/0"},
+ ip: "1.1.1.1",
+ ok: true,
+ err: nil,
+ },
+ {
+ name: "match_ipv6",
+ cidrs: []string{"::/0"},
+ ip: "cafe:babe::",
+ ok: true,
+ err: nil,
+ },
+ {
+ name: "not_match_ipv4",
+ cidrs: []string{"1.1.1.1/32"},
+ ip: "1.1.1.2",
+ ok: false,
+ err: nil,
+ },
+ {
+ name: "not_match_ipv6",
+ cidrs: []string{"cafe:babe::/128"},
+ ip: "cafe:babe:4::/128",
+ ok: false,
+ err: nil,
+ },
+ {
+ name: "no_ip_set",
+ cidrs: []string{"::/0"},
+ ok: false,
+ err: ErrMisconfiguration,
+ },
+ {
+ name: "invalid_ip",
+ cidrs: []string{"::/0"},
+ ip: "According to all natural laws of aviation",
+ ok: false,
+ err: ErrMisconfiguration,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ rac, err := NewRemoteAddrChecker(tt.cidrs)
+ if err != nil && !errors.Is(err, tt.err) {
+ t.Fatalf("creating RemoteAddrChecker failed: %v", err)
+ }
+
+ r, err := http.NewRequest(http.MethodGet, "/", nil)
+ if err != nil {
+ t.Fatalf("can't make request: %v", err)
+ }
+
+ if tt.ip != "" {
+ r.Header.Add("X-Real-Ip", tt.ip)
+ }
+
+ ok, err := rac.Check(r)
+
+ if tt.ok != ok {
+ t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
+ }
+
+ if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
+ t.Errorf("err: %v, wanted: %v", err, tt.err)
+ }
+ })
+ }
+}
+
+func TestHeaderMatchesChecker(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ header string
+ rexStr string
+ reqHeaderKey string
+ reqHeaderValue string
+ ok bool
+ err error
+ }{
+ {
+ name: "match",
+ header: "Cf-Worker",
+ rexStr: ".*",
+ reqHeaderKey: "Cf-Worker",
+ reqHeaderValue: "true",
+ ok: true,
+ err: nil,
+ },
+ {
+ name: "not_match",
+ header: "Cf-Worker",
+ rexStr: "false",
+ reqHeaderKey: "Cf-Worker",
+ reqHeaderValue: "true",
+ ok: false,
+ err: nil,
+ },
+ {
+ name: "not_present",
+ header: "Cf-Worker",
+ rexStr: "foobar",
+ reqHeaderKey: "Something-Else",
+ reqHeaderValue: "true",
+ ok: false,
+ err: nil,
+ },
+ {
+ name: "invalid_regex",
+ rexStr: "a(b",
+ err: ErrMisconfiguration,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr)
+ if err != nil && !errors.Is(err, tt.err) {
+ t.Fatalf("creating HeaderMatchesChecker failed")
+ }
+
+ if tt.err != nil && hmc == nil {
+ return
+ }
+
+ r, err := http.NewRequest(http.MethodGet, "/", nil)
+ if err != nil {
+ t.Fatalf("can't make request: %v", err)
+ }
+
+ r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
+
+ ok, err := hmc.Check(r)
+
+ if tt.ok != ok {
+ t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
+ }
+
+ if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
+ t.Errorf("err: %v, wanted: %v", err, tt.err)
+ }
+ })
+ }
+}
+
+func TestHeaderExistsChecker(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ header string
+ reqHeader string
+ ok bool
+ }{
+ {
+ name: "match",
+ header: "Authorization",
+ reqHeader: "Authorization",
+ ok: true,
+ },
+ {
+ name: "not_match",
+ header: "Authorization",
+ reqHeader: "Authentication",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ hec := headerExistsChecker{tt.header}
+
+ r, err := http.NewRequest(http.MethodGet, "/", nil)
+ if err != nil {
+ t.Fatalf("can't make request: %v", err)
+ }
+
+ r.Header.Set(tt.reqHeader, "hunter2")
+
+ ok, err := hec.Check(r)
+
+ if tt.ok != ok {
+ t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
+ }
+
+ if err != nil {
+ t.Errorf("err: %v", err)
+ }
+ })
+ }
+}
diff --git a/lib/policy/checkresult.go b/lib/policy/checkresult.go
new file mode 100644
index 0000000..c84f326
--- /dev/null
+++ b/lib/policy/checkresult.go
@@ -0,0 +1,18 @@
+package policy
+
+import (
+ "log/slog"
+
+ "github.com/TecharoHQ/anubis/lib/policy/config"
+)
+
+type CheckResult struct {
+ Name string
+ Rule config.Rule
+}
+
+func (cr CheckResult) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("name", cr.Name),
+ slog.String("rule", string(cr.Rule)))
+}
diff --git a/lib/policy/policy.go b/lib/policy/policy.go
index 2d610c8..5923f16 100644
--- a/lib/policy/policy.go
+++ b/lib/policy/policy.go
@@ -4,12 +4,9 @@ import (
"errors"
"fmt"
"io"
- "net"
- "regexp"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
- "github.com/yl2chen/cidranger"
"k8s.io/apimachinery/pkg/util/yaml"
"github.com/TecharoHQ/anubis/lib/policy/config"
@@ -58,57 +55,45 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
}
parsedBot := Bot{
- Name: b.Name,
- Action: b.Action,
- Headers: map[string]*regexp.Regexp{},
+ Name: b.Name,
+ Action: b.Action,
}
- if len(b.RemoteAddr) > 0 {
- parsedBot.Ranger = cidranger.NewPCTrieRanger()
-
- for _, cidr := range b.RemoteAddr {
- _, rng, err := net.ParseCIDR(cidr)
- if err != nil {
- return nil, fmt.Errorf("[unexpected] range %s not parsing: %w", cidr, err)
- }
+ cl := CheckerList{}
- parsedBot.Ranger.Insert(cidranger.NewBasicRangerEntry(*rng))
+ if len(b.RemoteAddr) > 0 {
+ c, err := NewRemoteAddrChecker(b.RemoteAddr)
+ if err != nil {
+ validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err))
+ } else {
+ cl = append(cl, c)
}
}
if b.UserAgentRegex != nil {
- userAgent, err := regexp.Compile(*b.UserAgentRegex)
+ c, err := NewUserAgentChecker(*b.UserAgentRegex)
if err != nil {
- validationErrs = append(validationErrs, fmt.Errorf("while compiling user agent regexp: %w", err))
- continue
+ validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err))
} else {
- parsedBot.UserAgent = userAgent
+ cl = append(cl, c)
}
}
if b.PathRegex != nil {
- path, err := regexp.Compile(*b.PathRegex)
+ c, err := NewPathChecker(*b.PathRegex)
if err != nil {
- validationErrs = append(validationErrs, fmt.Errorf("while compiling path regexp: %w", err))
- continue
+ validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err))
} else {
- parsedBot.Path = path
+ cl = append(cl, c)
}
}
if len(b.HeadersRegex) > 0 {
- for name, expr := range b.HeadersRegex {
- if name == "" {
- continue
- }
-
- header, err := regexp.Compile(expr)
- if err != nil {
- validationErrs = append(validationErrs, fmt.Errorf("while compiling header regexp: %w", err))
- continue
- } else {
- parsedBot.Headers[name] = header
- }
+ c, err := NewHeadersChecker(b.HeadersRegex)
+ if err != nil {
+ validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s headers regex map: %w", b.Name, err))
+ } else {
+ cl = append(cl, c)
}
}
@@ -125,6 +110,8 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
}
}
+ parsedBot.Rules = cl
+
result.Bots = append(result.Bots, parsedBot)
}