diff options
| author | Xe Iaso <me@xeiaso.net> | 2025-04-19 16:09:19 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2025-04-19 16:09:19 -0400 |
| commit | b5533db3c24657875901863126f46adae488e703 (patch) | |
| tree | ae1b7dbdfb8ea1f292995481f531b6dc628de264 /cmd | |
| parent | b2bc32ad7a215e257616bde979a6febcef1332e1 (diff) | |
| download | x-b5533db3c24657875901863126f46adae488e703.tar.xz x-b5533db3c24657875901863126f46adae488e703.zip | |
cmd/relayd: automagically reload TLS certificates, JA3N/JA4 fingerprints
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/relayd/fingerprint.go | 339 | ||||
| -rw-r--r-- | cmd/relayd/main.go | 126 |
2 files changed, 457 insertions, 8 deletions
diff --git a/cmd/relayd/fingerprint.go b/cmd/relayd/fingerprint.go new file mode 100644 index 0000000..64322a9 --- /dev/null +++ b/cmd/relayd/fingerprint.go @@ -0,0 +1,339 @@ +package main + +import ( + "context" + "crypto/md5" + "crypto/sha256" + "crypto/tls" + "encoding/hex" + "fmt" + "net" + "net/http" + "slices" + "strconv" + "strings" + "sync/atomic" +) + +func applyTLSFingerprinter(server *http.Server) { + server.TLSConfig = server.TLSConfig.Clone() + + getCertificate := server.TLSConfig.GetCertificate + if getCertificate == nil { + server.TLSConfig.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + ja3n, ja4 := buildTLSFingerprint(clientHello) + ptr := clientHello.Context().Value(tlsFingerprintKey{}) + if fpPtr, ok := ptr.(*TLSFingerprint); ok && ptr != nil && fpPtr != nil { + fpPtr.ja3n.Store(&ja3n) + fpPtr.ja4.Store(&ja4) + } + + return nil, nil + } + } else { + server.TLSConfig.GetCertificate = func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + ja3n, ja4 := buildTLSFingerprint(clientHello) + ptr := clientHello.Context().Value(tlsFingerprintKey{}) + if fpPtr, ok := ptr.(*TLSFingerprint); ok && ptr != nil && fpPtr != nil { + fpPtr.ja3n.Store(&ja3n) + fpPtr.ja4.Store(&ja4) + } + + return getCertificate(clientHello) + } + } + server.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, tlsFingerprintKey{}, &TLSFingerprint{}) + } +} + +type tlsFingerprintKey struct{} +type TLSFingerprint struct { + ja3n atomic.Pointer[TLSFingerprintJA3N] + ja4 atomic.Pointer[TLSFingerprintJA4] +} + +type TLSFingerprintJA3N [md5.Size]byte + +func (f TLSFingerprintJA3N) String() string { + return hex.EncodeToString(f[:]) +} + +type TLSFingerprintJA4 struct { + A [10]byte + B [6]byte + C [6]byte +} + +func (f TLSFingerprintJA4) String() string { + return strings.Join([]string{ + string(f.A[:]), + hex.EncodeToString(f.B[:]), + hex.EncodeToString(f.C[:]), + }, "_") +} + +func (f *TLSFingerprint) JA3N() *TLSFingerprintJA3N { + return f.ja3n.Load() +} + +func (f *TLSFingerprint) JA4() *TLSFingerprintJA4 { + return f.ja4.Load() +} + +const greaseMask = 0x0F0F +const greaseValue = 0x0a0a + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionExtendedMasterSecret uint16 = 23 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionCertificateAuthorities uint16 = 47 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionQUICTransportParameters uint16 = 57 + extensionRenegotiationInfo uint16 = 0xff01 + extensionECHOuterExtensions uint16 = 0xfd00 + extensionEncryptedClientHello uint16 = 0xfe0d +) + +func tlsFingerprintJA3(hello *tls.ClientHelloInfo, sortExtensions bool) []byte { + buf := make([]byte, 0, 256) + + { + var sslVersion uint16 + var hasGrease bool + for _, v := range hello.SupportedVersions { + if v&greaseMask != greaseValue { + if v > sslVersion { + sslVersion = v + } + } else { + hasGrease = true + } + } + + // maximum TLS 1.2 as specified on JA3, as TLS 1.3 is put in SupportedVersions + if slices.Contains(hello.Extensions, extensionSupportedVersions) && hasGrease && sslVersion > tls.VersionTLS12 { + sslVersion = tls.VersionTLS12 + } + + buf = strconv.AppendUint(buf, uint64(sslVersion), 10) + buf = append(buf, ',') + } + + n := 0 + for _, cipher := range hello.CipherSuites { + //if !slices.Contains(greaseValues[:], cipher) { + if cipher&greaseMask != greaseValue { + buf = strconv.AppendUint(buf, uint64(cipher), 10) + buf = append(buf, '-') + n = 1 + } + } + + buf = buf[:len(buf)-n] + buf = append(buf, ',') + n = 0 + + extensions := hello.Extensions + if sortExtensions { + extensions = slices.Clone(extensions) + slices.Sort(extensions) + } + + for _, extension := range extensions { + if extension&greaseMask != greaseValue { + buf = strconv.AppendUint(buf, uint64(extension), 10) + buf = append(buf, '-') + n = 1 + } + } + + buf = buf[:len(buf)-n] + buf = append(buf, ',') + n = 0 + + for _, curve := range hello.SupportedCurves { + if curve&greaseMask != greaseValue { + buf = strconv.AppendUint(buf, uint64(curve), 10) + buf = append(buf, '-') + n = 1 + } + } + + buf = buf[:len(buf)-n] + buf = append(buf, ',') + n = 0 + + for _, point := range hello.SupportedPoints { + buf = strconv.AppendUint(buf, uint64(point), 10) + buf = append(buf, '-') + n = 1 + } + + buf = buf[:len(buf)-n] + + sum := md5.Sum(buf) + return sum[:] +} + +func tlsFingerprintJA4(hello *tls.ClientHelloInfo) (ja4 TLSFingerprintJA4) { + buf := make([]byte, 0, 10) + + // TODO: t = TLS, q = QUIC + buf = append(buf, 't') + + { + var sslVersion uint16 + for _, v := range hello.SupportedVersions { + if v&greaseMask != greaseValue { + if v > sslVersion { + sslVersion = v + } + } + } + + switch sslVersion { + case tls.VersionSSL30: + buf = append(buf, 's', '3') + case tls.VersionTLS10: + buf = append(buf, '1', '0') + case tls.VersionTLS11: + buf = append(buf, '1', '1') + case tls.VersionTLS12: + buf = append(buf, '1', '2') + case tls.VersionTLS13: + buf = append(buf, '1', '3') + default: + sslVersion -= 0x0201 + buf = strconv.AppendUint(buf, uint64(sslVersion>>8), 10) + buf = strconv.AppendUint(buf, uint64(sslVersion&0xff), 10) + } + + } + + if slices.Contains(hello.Extensions, extensionServerName) && hello.ServerName != "" { + buf = append(buf, 'd') + } else { + buf = append(buf, 'i') + } + + ciphers := make([]uint16, 0, len(hello.CipherSuites)) + for _, cipher := range hello.CipherSuites { + if cipher&greaseMask != greaseValue { + ciphers = append(ciphers, cipher) + } + } + + extensionCount := 0 + extensions := make([]uint16, 0, len(hello.Extensions)) + for _, extension := range hello.Extensions { + if extension&greaseMask != greaseValue { + extensionCount++ + if extension != extensionALPN && extension != extensionServerName { + extensions = append(extensions, extension) + } + } + } + + schemes := make([]tls.SignatureScheme, 0, len(hello.SignatureSchemes)) + + for _, scheme := range hello.SignatureSchemes { + if scheme&greaseMask != greaseValue { + schemes = append(schemes, scheme) + } + } + + //TODO: maybe little endian + slices.Sort(ciphers) + slices.Sort(extensions) + //slices.Sort(schemes) + + if len(ciphers) < 10 { + buf = append(buf, '0') + buf = strconv.AppendUint(buf, uint64(len(ciphers)), 10) + } else if len(ciphers) > 99 { + buf = append(buf, '9', '9') + } else { + buf = strconv.AppendUint(buf, uint64(len(ciphers)), 10) + } + + if extensionCount < 10 { + buf = append(buf, '0') + buf = strconv.AppendUint(buf, uint64(extensionCount), 10) + } else if extensionCount > 99 { + buf = append(buf, '9', '9') + } else { + buf = strconv.AppendUint(buf, uint64(extensionCount), 10) + } + + if len(hello.SupportedProtos) > 0 && len(hello.SupportedProtos[0]) > 1 { + buf = append(buf, hello.SupportedProtos[0][0], hello.SupportedProtos[0][len(hello.SupportedProtos[0])-1]) + } else { + buf = append(buf, '0', '0') + } + + copy(ja4.A[:], buf) + + ja4.B = ja4SHA256(uint16SliceToHex(ciphers)) + + extBuf := uint16SliceToHex(extensions) + + if len(schemes) > 0 { + extBuf = append(extBuf, '_') + extBuf = append(extBuf, uint16SliceToHex(schemes)...) + } + + ja4.C = ja4SHA256(extBuf) + + return ja4 +} + +func uint16SliceToHex[T ~uint16](in []T) (out []byte) { + if len(in) == 0 { + return out + } + out = slices.Grow(out, hex.EncodedLen(len(in)*2)+len(in)) + + for _, n := range in { + out = append(out, fmt.Sprintf("%04x", uint16(n))...) + out = append(out, ',') + } + out = out[:len(out)-1] + + return out +} + +func ja4SHA256(buf []byte) [6]byte { + if len(buf) == 0 { + return [6]byte{0, 0, 0, 0, 0, 0} + } + sum := sha256.Sum256(buf) + + return [6]byte(sum[:6]) +} + +func buildTLSFingerprint(hello *tls.ClientHelloInfo) (ja3n TLSFingerprintJA3N, ja4 TLSFingerprintJA4) { + return TLSFingerprintJA3N(tlsFingerprintJA3(hello, true)), tlsFingerprintJA4(hello) +} + +func GetTLSFingerprint(r *http.Request) *TLSFingerprint { + ptr := r.Context().Value(tlsFingerprintKey{}) + if fpPtr, ok := ptr.(*TLSFingerprint); ok && ptr != nil && fpPtr != nil { + return fpPtr + } + return nil +} diff --git a/cmd/relayd/main.go b/cmd/relayd/main.go index 18d94cc..4ee685a 100644 --- a/cmd/relayd/main.go +++ b/cmd/relayd/main.go @@ -1,14 +1,20 @@ package main import ( + "crypto/tls" "flag" + "fmt" "log" "log/slog" + "net" "net/http" "net/http/httputil" "net/url" "os" + "os/signal" "path/filepath" + "sync" + "syscall" "time" "within.website/x/internal" @@ -68,12 +74,116 @@ func main() { log.Fatal(err) } - log.Fatal( - http.ListenAndServeTLS( - *bind, - cert, - key, - httputil.NewSingleHostReverseProxy(u), - ), - ) + kpr, err := NewKeypairReloader(cert, key) + if err != nil { + log.Fatal(err) + } + + h := httputil.NewSingleHostReverseProxy(u) + oldDirector := h.Director + + h.Director = func(req *http.Request) { + oldDirector(req) + + host, _, _ := net.SplitHostPort(req.RemoteAddr) + if host != "" { + req.Header.Set("X-Real-Ip", host) + } + + fp := GetTLSFingerprint(req) + if fp != nil { + req.Header.Set("JA3N-Fingerprint", fp.JA3N().String()) + req.Header.Set("JA4-Fingerprint", fp.JA4().String()) + } + } + + srv := &http.Server{ + Addr: *bind, + Handler: h, + TLSConfig: &tls.Config{ + GetCertificate: kpr.GetCertificate, + }, + } + + applyTLSFingerprinter(srv) + + log.Fatal(srv.ListenAndServeTLS("", "")) +} + +type keypairReloader struct { + certMu sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string + modTime time.Time +} + +func NewKeypairReloader(certPath, keyPath string) (*keypairReloader, error) { + result := &keypairReloader{ + certPath: certPath, + keyPath: keyPath, + } + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + result.cert = &cert + + st, err := os.Stat(certPath) + if err != nil { + return nil, err + } + result.modTime = st.ModTime() + + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP) + for range c { + slog.Info("got SIGHUP") + if err := result.maybeReload(); err != nil { + slog.Error("can't load tls cert", "err", err) + } + } + }() + return result, nil +} + +func (kpr *keypairReloader) maybeReload() error { + slog.Info("loading new keypair", "cert", kpr.certPath, "key", kpr.keyPath) + newCert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath) + if err != nil { + return err + } + + st, err := os.Stat(kpr.certPath) + if err != nil { + return err + } + + kpr.certMu.Lock() + defer kpr.certMu.Unlock() + kpr.cert = &newCert + kpr.modTime = st.ModTime() + + return nil +} + +func (kpr *keypairReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + kpr.certMu.RLock() + defer kpr.certMu.RUnlock() + + st, err := os.Stat(kpr.certPath) + if err != nil { + return nil, fmt.Errorf("internal error: stat(%q): %q", kpr.certPath, err) + } + + if st.ModTime().After(kpr.modTime) { + kpr.certMu.RUnlock() + if err := kpr.maybeReload(); err != nil { + return nil, fmt.Errorf("can't reload cert: %w", err) + } + kpr.certMu.RLock() + } + + return kpr.cert, nil } |
