diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-01-28 13:21:49 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-01-28 13:24:36 -0500 |
| commit | 2f5df3bf784fc77abd1336301bdb38d3bd318387 (patch) | |
| tree | bdd8108f03b05c3c05c8d0251c7b6ac3706b56fc | |
| parent | 57bd9082dd4f01fab353db4485819a2f87416245 (diff) | |
| download | xesite-2f5df3bf784fc77abd1336301bdb38d3bd318387.tar.xz xesite-2f5df3bf784fc77abd1336301bdb38d3bd318387.zip | |
internal: add validation for referers and accept-encoding
Signed-off-by: Xe Iaso <me@xeiaso.net>
| -rw-r--r-- | internal/accept_encoding.go | 89 | ||||
| -rw-r--r-- | internal/accept_encoding_test.go | 109 | ||||
| -rw-r--r-- | internal/referer.go | 8 |
3 files changed, 204 insertions, 2 deletions
diff --git a/internal/accept_encoding.go b/internal/accept_encoding.go index 9eca7ed..395cf24 100644 --- a/internal/accept_encoding.go +++ b/internal/accept_encoding.go @@ -3,21 +3,108 @@ package internal import ( "expvar" "net/http" + "strconv" + "strings" "tailscale.com/metrics" ) var ( acceptEncodings = &metrics.LabelMap{Label: "encoding"} + + validEncodings = []string{ + "gzip", + "x-gzip", + "deflate", + "br", + "identity", + "snappy", + "bzip2", + "lzma", + "zstd", + } ) func init() { expvar.Publish("gauge_xesite_accept_encoding", acceptEncodings) } +func inValidEncodings(enc string) bool { + for _, validEnc := range validEncodings { + if enc == validEnc { + return true + } + } + return false +} + func AcceptEncodingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - acceptEncodings.Add(r.Header.Get("Accept-Encoding"), 1) + for _, enc := range ParseAcceptEncoding(r.Header.Get("Accept-Encoding")) { + if !inValidEncodings(enc.Encoding) { + continue + } + acceptEncodings.Add(enc.Encoding, 1) + } + next.ServeHTTP(w, r) }) } + +type EncodingQ struct { + Encoding string + Q float64 +} + +func ParseAcceptEncoding(acptEnc string) []EncodingQ { + var eqs []EncodingQ + + encQStrs := strings.Split(acptEnc, ",") + for _, encQStr := range encQStrs { + trimedEncQStr := strings.Trim(encQStr, " ") + + encQ := strings.Split(trimedEncQStr, ";") + if len(encQ) == 1 { + eq := EncodingQ{encQ[0], 1} + eqs = append(eqs, eq) + } else { + qp := strings.Split(encQ[1], "=") + q, err := strconv.ParseFloat(qp[1], 64) + if err != nil { + panic(err) + } + eq := EncodingQ{encQ[0], q} + eqs = append(eqs, eq) + } + } + return eqs +} + +type LangQ struct { + Lang string + Q float64 +} + +func ParseAcceptLanguage(acptLang string) []LangQ { + var lqs []LangQ + + langQStrs := strings.Split(acptLang, ",") + for _, langQStr := range langQStrs { + trimedLangQStr := strings.Trim(langQStr, " ") + + langQ := strings.Split(trimedLangQStr, ";") + if len(langQ) == 1 { + lq := LangQ{langQ[0], 1} + lqs = append(lqs, lq) + } else { + qp := strings.Split(langQ[1], "=") + q, err := strconv.ParseFloat(qp[1], 64) + if err != nil { + panic(err) + } + lq := LangQ{langQ[0], q} + lqs = append(lqs, lq) + } + } + return lqs +} diff --git a/internal/accept_encoding_test.go b/internal/accept_encoding_test.go new file mode 100644 index 0000000..503579f --- /dev/null +++ b/internal/accept_encoding_test.go @@ -0,0 +1,109 @@ +package internal + +import "testing" + +func TestInValidEncodings(t *testing.T) { + tests := []struct { + enc string + ok bool + }{ + {"gzip", true}, + {"x-gzip", true}, + {"tacobell", false}, + } + + for _, test := range tests { + t.Run(test.enc, func(t *testing.T) { + ok := inValidEncodings(test.enc) + if ok != test.ok { + t.Errorf("ok = %t, want %t", ok, test.ok) + } + }) + } +} + +func TestParseAcceptLanguage(t *testing.T) { + acptLang := "en-US,en;q=0.9,ja-JP;q=0.8,ja;q=0.7" + lqs := ParseAcceptLanguage(acptLang) + if len(lqs) != 4 { + t.Errorf("len(lqs) = %d, want 4", len(lqs)) + } + if lqs[0].Lang != "en-US" { + t.Errorf("lqs[0].Lang = %s, want en-US", lqs[0].Lang) + } + if lqs[0].Q != 1 { + t.Errorf("lqs[0].Q = %f, want 1", lqs[0].Q) + } + if lqs[1].Lang != "en" { + t.Errorf("lqs[1].Lang = %s, want en", lqs[1].Lang) + } + if lqs[1].Q != 0.9 { + t.Errorf("lqs[1].Q = %f, want 0.9", lqs[1].Q) + } + if lqs[2].Lang != "ja-JP" { + t.Errorf("lqs[2].Lang = %s, want ja-JP", lqs[2].Lang) + } + if lqs[2].Q != 0.8 { + t.Errorf("lqs[2].Q = %f, want 0.8", lqs[2].Q) + } + if lqs[3].Lang != "ja" { + t.Errorf("lqs[3].Lang = %s, want ja", lqs[3].Lang) + } + if lqs[3].Q != 0.7 { + t.Errorf("lqs[3].Q = %f, want 0.7", lqs[3].Q) + } + + t.Run("invalid", func(t *testing.T) { + panicked := false + acptEnc := "taco;q=beer" + defer func() { + if r := recover(); r != nil { + panicked = true + } + if !panicked { + t.Errorf("did not panic") + } + }() + ParseAcceptLanguage(acptEnc) + }) +} + +func TestParseAcceptEncoding(t *testing.T) { + acptEnc := "gzip, deflate, br;q=0.9" + eqs := ParseAcceptEncoding(acptEnc) + if len(eqs) != 3 { + t.Errorf("len(eqs) = %d, want 3", len(eqs)) + } + if eqs[0].Encoding != "gzip" { + t.Errorf("eqs[0].Encoding = %s, want gzip", eqs[0].Encoding) + } + if eqs[0].Q != 1 { + t.Errorf("eqs[0].Q = %f, want 1", eqs[0].Q) + } + if eqs[1].Encoding != "deflate" { + t.Errorf("eqs[1].Encoding = %s, want deflate", eqs[1].Encoding) + } + if eqs[1].Q != 1 { + t.Errorf("eqs[1].Q = %f, want 1", eqs[1].Q) + } + if eqs[2].Encoding != "br" { + t.Errorf("eqs[2].Encoding = %s, want br", eqs[2].Encoding) + } + if eqs[2].Q != 1 { + t.Errorf("eqs[2].Q = %f, want 1", eqs[2].Q) + } + + t.Run("invalid", func(t *testing.T) { + panicked := false + acptEnc := "gzip, deflate, taco;q=beer" + defer func() { + if r := recover(); r != nil { + panicked = true + } + if !panicked { + t.Errorf("did not panic") + } + }() + ParseAcceptEncoding(acptEnc) + }) +} diff --git a/internal/referer.go b/internal/referer.go index ec8c2be..3a7d16f 100644 --- a/internal/referer.go +++ b/internal/referer.go @@ -3,6 +3,7 @@ package internal import ( "expvar" "net/http" + "net/url" "tailscale.com/metrics" ) @@ -17,7 +18,12 @@ func init() { func RefererMiddleware(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - referers.Add(r.Header.Get("Referer"), 1) + if referer := r.Header.Get("Referer"); referer != "" { + _, err := url.Parse(referer) + if err == nil { + referers.Add(referer, 1) + } + } next.ServeHTTP(w, r) } } |
