From 03df9f24b6ec6a8d1020cd8a1e09ead21ec6b294 Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Tue, 14 Apr 2026 18:37:18 -0300 Subject: [PATCH 01/17] feat(service): add HTTP server helpers and error response writer --- pkg/service/http_server.go | 191 +++++++++++++++++++++++++ pkg/service/http_server_test.go | 239 ++++++++++++++++++++++++++++++++ 2 files changed, 430 insertions(+) create mode 100644 pkg/service/http_server.go create mode 100644 pkg/service/http_server_test.go diff --git a/pkg/service/http_server.go b/pkg/service/http_server.go new file mode 100644 index 000000000..ef2fd2cb8 --- /dev/null +++ b/pkg/service/http_server.go @@ -0,0 +1,191 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/http" + "time" +) + +// HTTPServerOptions bundles the hardening knobs applied to an [http.Server]. +// +// Every field is required. Callers typically start from one of the package-level +// preset constructors ([DefaultInspectOptions], [DefaultTelemetryOptions], +// [DefaultJSONRPCOptions]) and mutate fields on the returned value. +type HTTPServerOptions struct { + ReadHeaderTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration + MaxHeaderBytes int +} + +// DefaultInspectOptions returns a fresh [HTTPServerOptions] preset for the +// inspect HTTP surface. WriteTimeout is 600s (10 minutes) and serves as a +// process-health backstop only — it prevents leaked goroutines from holding +// a connection forever. The actual per-request deadline is set structurally +// in Inspector.ServeHTTP as InspectMaxDeadline + inspectResponseHeadroom. +// +// Each call returns an independent value; callers may mutate it freely +// without affecting other callers. +// +//nolint:mnd // These timers are the canonical source of these numbers. 
+func DefaultInspectOptions() HTTPServerOptions { + return HTTPServerOptions{ + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 600 * time.Second, + IdleTimeout: 60 * time.Second, + MaxHeaderBytes: 64 * 1024, + } +} + +// DefaultTelemetryOptions returns a fresh [HTTPServerOptions] preset for the +// telemetry HTTP surface. Each call returns an independent value; callers +// may mutate it freely without affecting other callers. +// +//nolint:mnd // These timers are the canonical source of these numbers. +func DefaultTelemetryOptions() HTTPServerOptions { + return HTTPServerOptions{ + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 60 * time.Second, + MaxHeaderBytes: 16 * 1024, + } +} + +// DefaultJSONRPCOptions returns a fresh [HTTPServerOptions] preset for the +// JSON-RPC HTTP surface. Each call returns an independent value; callers +// may mutate it freely without affecting other callers. +// +//nolint:mnd // These timers are the canonical source of these numbers. +func DefaultJSONRPCOptions() HTTPServerOptions { + return HTTPServerOptions{ + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + MaxHeaderBytes: 64 * 1024, + } +} + +// NewHTTPServer builds an [http.Server] from opts with an ErrorLog routed +// through logger. 
+func NewHTTPServer( + addr string, + handler http.Handler, + opts HTTPServerOptions, + logger *slog.Logger, +) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: opts.ReadHeaderTimeout, + ReadTimeout: opts.ReadTimeout, + WriteTimeout: opts.WriteTimeout, + IdleTimeout: opts.IdleTimeout, + MaxHeaderBytes: opts.MaxHeaderBytes, + ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), + } +} + +// StartupBindWarning logs a warning when addr binds to an address that +// exposes the HTTP service beyond the local host. Specifically: +// +// - Unspecified address (0.0.0.0, ::, or an empty host like ":10012") — +// WARN: reachable on every interface. +// - Private RFC1918/RFC4193 address (192.168/16, 10/8, 172.16/12, ULA +// fc00::/7) or IPv6 link-local (fe80::/10) — WARN: reachable by other +// hosts on the same LAN/VPC, which is the most common multi-tenant +// foot-gun. +// +// Loopback (127.0.0.0/8, ::1) and public/hostname binds stay silent: the +// former is the recommended posture, the latter is assumed to be a +// deliberate choice the operator has already reasoned about. +// +// If [net.SplitHostPort] cannot parse addr, this logs at INFO so operator +// typos like "127.0.0.1.10012" (missing colon) are visible — the helper +// made no statement, and the subsequent ListenAndServe will fail loudly. 
func StartupBindWarning(logger *slog.Logger, serviceName, addr string) {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		logger.Info(
			"StartupBindWarning could not parse address; skipping bind-exposure check",
			"service", serviceName,
			"addr", addr,
			"err", err,
		)
		return
	}

	// net.SplitHostPort keeps an IPv6 zone identifier in the host part
	// ("fe80::1%eth0"), but net.ParseIP rejects zoned literals. Drop the
	// zone before parsing so a zoned bind is classified exactly like its
	// zone-less form instead of being mistaken for a hostname and staying
	// silent. A '%' cannot occur in an IPv4 literal or a DNS name, so the
	// cut is a no-op for every other kind of host.
	literal := host
	for i := 0; i < len(literal); i++ {
		if literal[i] == '%' {
			literal = literal[:i]
			break
		}
	}

	var ip net.IP
	if literal != "" {
		ip = net.ParseIP(literal)
	}
	switch {
	case host == "" || (ip != nil && ip.IsUnspecified()):
		logger.Warn(
			"HTTP service bound to all interfaces; restrict access via firewall or reverse proxy",
			"service", serviceName,
			"addr", addr,
		)
	case ip != nil && (ip.IsPrivate() || ip.IsLinkLocalUnicast()):
		logger.Warn(
			"HTTP service bound to a private/link-local address; restrict access via firewall or reverse proxy",
			"service", serviceName,
			"addr", addr,
		)
	}
	// Loopback, public literal, or non-IP hostname: stay silent.
}

// ctxKeyRequestID is the context key under which [RequestIDMiddleware] stores
// the validated or generated X-Request-ID value.
type ctxKeyRequestID struct{}

// RequestIDFromContext returns the request id stored by [RequestIDMiddleware],
// or the empty string if none is present. A nil ctx is tolerated and also
// yields the empty string.
func RequestIDFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	if v, ok := ctx.Value(ctxKeyRequestID{}).(string); ok {
		return v
	}
	return ""
}

// WriteInternalError writes a generic 500 response and logs the detailed
// error. The response body is exactly:
//
//	Internal server error (request_id=<id>)\n
//
// where <id> comes from [RequestIDFromContext] (empty if no middleware has
// populated it). The caller's err is never echoed onto the wire — it only
// appears in the structured log under the "err" key. The log message is a
// fixed string ("http internal error") so operators can grep reliably on
// the structured field rather than a per-call-site prefix.
+func WriteInternalError( + ctx context.Context, + w http.ResponseWriter, + logger *slog.Logger, + err error, +) { + reqID := RequestIDFromContext(ctx) + logger.Error("http internal error", + "err", err, + "request_id", reqID, + ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusInternalServerError) + // gosec G705 is a false positive here: the response is text/plain with + // nosniff, and reqID is validated upstream by RequestIDMiddleware (requestIDPattern), + // so no taint can reach a browser as HTML. + _, _ = fmt.Fprintf(w, "Internal server error (request_id=%s)\n", reqID) //nolint:gosec // G705 +} diff --git a/pkg/service/http_server_test.go b/pkg/service/http_server_test.go new file mode 100644 index 000000000..64a5a0fff --- /dev/null +++ b/pkg/service/http_server_test.go @@ -0,0 +1,239 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +// captureLogger returns a logger whose output is written to buf. Level is set +// to debug so every call is recorded. 
+func captureLogger(buf *bytes.Buffer) *slog.Logger { + return slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) +} + +func TestHTTPServerOptions_PresetValues(t *testing.T) { + inspect := DefaultInspectOptions() + require.Equal(t, 10*time.Second, inspect.ReadHeaderTimeout) + require.Equal(t, 30*time.Second, inspect.ReadTimeout) + require.Equal(t, 600*time.Second, inspect.WriteTimeout) + require.Equal(t, 60*time.Second, inspect.IdleTimeout) + require.Equal(t, 64*1024, inspect.MaxHeaderBytes) + + telemetry := DefaultTelemetryOptions() + require.Equal(t, 5*time.Second, telemetry.ReadHeaderTimeout) + require.Equal(t, 10*time.Second, telemetry.ReadTimeout) + require.Equal(t, 10*time.Second, telemetry.WriteTimeout) + require.Equal(t, 60*time.Second, telemetry.IdleTimeout) + require.Equal(t, 16*1024, telemetry.MaxHeaderBytes) + + jsonrpc := DefaultJSONRPCOptions() + require.Equal(t, 10*time.Second, jsonrpc.ReadHeaderTimeout) + require.Equal(t, 30*time.Second, jsonrpc.ReadTimeout) + require.Equal(t, 30*time.Second, jsonrpc.WriteTimeout) + require.Equal(t, 60*time.Second, jsonrpc.IdleTimeout) + require.Equal(t, 64*1024, jsonrpc.MaxHeaderBytes) +} + +// TestDefaultInspectOptions_ReturnsFreshCopy pins that the preset constructor +// hands every caller an independent value. Mutating one returned struct must +// not affect the next call's result. +func TestDefaultInspectOptions_ReturnsFreshCopy(t *testing.T) { + first := DefaultInspectOptions() + first.WriteTimeout = 0 + require.Equal(t, time.Duration(0), first.WriteTimeout) + second := DefaultInspectOptions() + require.Equal(t, 600*time.Second, second.WriteTimeout) +} + +// TestInspectWriteTimeoutExceedsMachineDeadline pins that the backstop +// WriteTimeout is well above any realistic machine deadline. 
+func TestInspectWriteTimeoutExceedsMachineDeadline(t *testing.T) { + require.Greater(t, DefaultInspectOptions().WriteTimeout, 180*time.Second) +} + +func TestNewHTTPServer_AppliesOptions(t *testing.T) { + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) + opts := DefaultInspectOptions() + srv := NewHTTPServer(":0", handler, opts, discardLogger()) + + require.Equal(t, ":0", srv.Addr) + require.NotNil(t, srv.Handler) + require.Equal(t, opts.ReadHeaderTimeout, srv.ReadHeaderTimeout) + require.Equal(t, opts.ReadTimeout, srv.ReadTimeout) + require.Equal(t, opts.WriteTimeout, srv.WriteTimeout) + require.Equal(t, opts.IdleTimeout, srv.IdleTimeout) + require.Equal(t, opts.MaxHeaderBytes, srv.MaxHeaderBytes) +} + +func TestNewHTTPServer_ErrorLogWiring(t *testing.T) { + var buf bytes.Buffer + srv := NewHTTPServer(":0", http.NotFoundHandler(), DefaultInspectOptions(), captureLogger(&buf)) + require.NotNil(t, srv.ErrorLog) + // Write through the ErrorLog and check it lands in the captured output. 
+ srv.ErrorLog.Print("test-error-log-line") + require.Contains(t, buf.String(), "test-error-log-line") +} + +func TestStartupBindWarning_UnspecifiedV4(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "0.0.0.0:10012") + require.Contains(t, buf.String(), "bound to all interfaces") + require.Contains(t, buf.String(), "0.0.0.0:10012") +} + +func TestStartupBindWarning_UnspecifiedV6(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "[::]:10012") + require.Contains(t, buf.String(), "bound to all interfaces") +} + +func TestStartupBindWarning_EmptyHost(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", ":10012") + require.Contains(t, buf.String(), "bound to all interfaces") +} + +func TestStartupBindWarning_Localhost(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "127.0.0.1:10012") + require.Empty(t, buf.String()) +} + +func TestStartupBindWarning_LocalhostV6(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "[::1]:10012") + require.Empty(t, buf.String()) +} + +func TestStartupBindWarning_Hostname(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "myhost:10012") + require.Empty(t, buf.String()) +} + +func TestStartupBindWarning_PrivateV4_192(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "192.168.1.10:10012") + require.Contains(t, buf.String(), "level=WARN") + require.Contains(t, buf.String(), "private/link-local") + require.Contains(t, buf.String(), "192.168.1.10:10012") +} + +func TestStartupBindWarning_PrivateV4_10(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "10.0.0.5:10012") + require.Contains(t, buf.String(), "level=WARN") + require.Contains(t, buf.String(), "private/link-local") + require.Contains(t, buf.String(), "10.0.0.5:10012") +} + +func 
TestStartupBindWarning_PrivateV4_172(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "172.20.3.4:10012") + require.Contains(t, buf.String(), "level=WARN") + require.Contains(t, buf.String(), "private/link-local") +} + +func TestStartupBindWarning_PrivateV6_ULA(t *testing.T) { + var buf bytes.Buffer + StartupBindWarning(captureLogger(&buf), "test", "[fc00::1]:10012") + require.Contains(t, buf.String(), "level=WARN") + require.Contains(t, buf.String(), "private/link-local") +} + +func TestStartupBindWarning_LinkLocalV6(t *testing.T) { + var buf bytes.Buffer + // Using a bare link-local literal: net.SplitHostPort accepts zone ids + // but net.ParseIP does not, so we use the zone-less form here. + StartupBindWarning(captureLogger(&buf), "test", "[fe80::1]:10012") + require.Contains(t, buf.String(), "level=WARN") + require.Contains(t, buf.String(), "private/link-local") +} + +func TestStartupBindWarning_MalformedAddr(t *testing.T) { + var buf bytes.Buffer + // SplitHostPort fails: we log at INFO so the operator sees the helper + // bailed on their (likely mistyped) address rather than silently + // making no statement. 
+ StartupBindWarning(captureLogger(&buf), "test", "not-a-valid-addr") + out := buf.String() + require.Contains(t, out, "level=INFO") + require.Contains(t, out, "could not parse address") + require.Contains(t, out, "not-a-valid-addr") + require.NotContains(t, out, "level=WARN") +} + +func TestWriteInternalError_BodyFormat(t *testing.T) { + rr := httptest.NewRecorder() + WriteInternalError(context.Background(), rr, discardLogger(), errors.New("inner")) + + require.Equal(t, http.StatusInternalServerError, rr.Code) + require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type")) + require.Equal(t, "Internal server error (request_id=)\n", rr.Body.String()) +} + +func TestWriteInternalError_BodyFormatWithID(t *testing.T) { + rr := httptest.NewRecorder() + ctx := context.WithValue(context.Background(), ctxKeyRequestID{}, "abc-123") + WriteInternalError(ctx, rr, discardLogger(), errors.New("inner")) + + require.Equal(t, "Internal server error (request_id=abc-123)\n", rr.Body.String()) +} + +func TestWriteInternalError_NeverLeaksErrText(t *testing.T) { + rr := httptest.NewRecorder() + WriteInternalError(context.Background(), rr, discardLogger(), errors.New("secret-detail-12345")) + + require.NotContains(t, rr.Body.String(), "secret-detail-12345") +} + +func TestWriteInternalError_LogsDetail(t *testing.T) { + var buf bytes.Buffer + rr := httptest.NewRecorder() + ctx := context.WithValue(context.Background(), ctxKeyRequestID{}, "log-id-9") + WriteInternalError(ctx, rr, captureLogger(&buf), errors.New("secret-detail-12345")) + + logged := buf.String() + require.Contains(t, logged, "http internal error") + require.Contains(t, logged, "secret-detail-12345") + require.Contains(t, logged, "log-id-9") +} + +func TestWriteInternalError_StatusAndContentType(t *testing.T) { + rr := httptest.NewRecorder() + WriteInternalError(context.Background(), rr, discardLogger(), errors.New("x")) + + require.Equal(t, http.StatusInternalServerError, rr.Code) + require.Equal(t, 
"text/plain; charset=utf-8", rr.Header().Get("Content-Type")) + // Defense-in-depth: match http.Error's behaviour so content-sniffing + // clients cannot reinterpret the text/plain body as HTML. + require.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) + // Sanity: body starts with the canonical prefix. + require.True(t, strings.HasPrefix(rr.Body.String(), "Internal server error (request_id=")) +} + +func TestRequestIDFromContext_Empty(t *testing.T) { + require.Equal(t, "", RequestIDFromContext(context.Background())) +} + +func TestRequestIDFromContext_Set(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxKeyRequestID{}, "xyz") + require.Equal(t, "xyz", RequestIDFromContext(ctx)) +} From 24df2afb7a063f42668d9e654bd680d677be222c Mon Sep 17 00:00:00 2001 From: Victor Fusco <1221933+vfusco@users.noreply.github.com> Date: Tue, 14 Apr 2026 18:55:05 -0300 Subject: [PATCH 02/17] feat(service): add HTTP middleware for recover and request-id --- go.mod | 2 +- pkg/service/http_middleware.go | 162 ++++++++++++ pkg/service/http_middleware_test.go | 396 ++++++++++++++++++++++++++++ 3 files changed, 559 insertions(+), 1 deletion(-) create mode 100644 pkg/service/http_middleware.go create mode 100644 pkg/service/http_middleware_test.go diff --git a/go.mod b/go.mod index 4948c424f..643857e04 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/deepmap/oapi-codegen/v2 v2.2.0 github.com/go-jet/jet/v2 v2.14.1 github.com/golang-migrate/migrate/v4 v4.19.1 + github.com/google/uuid v1.6.0 github.com/hashicorp/go-retryablehttp v0.7.8 github.com/jackc/pgx/v5 v5.8.0 github.com/lmittmann/tint v1.1.3 @@ -66,7 +67,6 @@ require ( github.com/go-openapi/swag v0.23.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect 
github.com/holiman/uint256 v1.3.2 // indirect diff --git a/pkg/service/http_middleware.go b/pkg/service/http_middleware.go new file mode 100644 index 000000000..a4f5275bb --- /dev/null +++ b/pkg/service/http_middleware.go @@ -0,0 +1,162 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bufio" + "context" + "fmt" + "log/slog" + "net" + "net/http" + "regexp" + "runtime/debug" + + "github.com/google/uuid" +) + +const ( + requestIDHeader = "X-Request-ID" +) + +// requestIDPattern defines the accepted charset and length for +// upstream-supplied request ids. Anything outside this set is treated as +// untrusted and a fresh UUID is generated instead. +// +// The charset is deliberately chosen to cover the ID formats emitted by +// common reverse proxies and tracing systems while remaining safe to log +// and echo in response headers: +// +// - envoy uses `.` and `:` +// - AWS ALB and X-Ray use `-` and `=` +// - GCP Cloud Trace uses `/` +// - `+` appears in some base64-style correlation IDs +// +// Explicitly excluded: `\r`, `\n`, space, `<`, `>`, `"`, `'`, backtick, +// and all other control characters — these would enable log-injection or +// header-splitting if echoed back verbatim. The `{1,128}` quantifier is +// the single source of truth for the length bound. +var requestIDPattern = regexp.MustCompile(`^[A-Za-z0-9._:=/+-]{1,128}$`) + +// RequestIDMiddleware reads the X-Request-ID request header and trusts it +// only when it matches ^[A-Za-z0-9._:=/+-]{1,128}$. Otherwise a fresh +// UUIDv4 is generated. The chosen id is placed on the request context +// under ctxKeyRequestID{} and echoed on the response as X-Request-ID. 
+func RequestIDMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := r.Header.Get(requestIDHeader) + if !requestIDPattern.MatchString(id) { + id = uuid.NewString() + } + w.Header().Set(requestIDHeader, id) + ctx := context.WithValue(r.Context(), ctxKeyRequestID{}, id) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// responseWriterTap wraps an http.ResponseWriter and records whether any +// header or body bytes have been sent. It implements Unwrap so that helpers +// such as http.MaxBytesReader and http.ResponseController can walk the +// wrapper chain to reach the underlying *response and access its internal +// interfaces (for example, to force a connection close after 413). +// +// Without Unwrap, oversized POST bodies still return 413 but the underlying +// connection is not force-closed and clients may see protocol confusion on a +// subsequent pipelined request. Do not remove Unwrap. +type responseWriterTap struct { + http.ResponseWriter + wroteHeader bool +} + +func (t *responseWriterTap) WriteHeader(code int) { + t.wroteHeader = true + t.ResponseWriter.WriteHeader(code) +} + +func (t *responseWriterTap) Write(b []byte) (int, error) { + t.wroteHeader = true + return t.ResponseWriter.Write(b) +} + +func (t *responseWriterTap) Unwrap() http.ResponseWriter { + return t.ResponseWriter +} + +// Flush forwards to the wrapped writer's Flusher implementation if it +// supports streaming. Because responseWriterTap embeds the +// http.ResponseWriter interface (not a concrete type), Go's method +// promotion only surfaces Header/Write/WriteHeader — optional interfaces +// like http.Flusher are not promoted and a direct +// w.(http.Flusher).Flush() on the tap would fail the type assertion even +// when the underlying writer supports it. Forward explicitly so handlers +// can stream through RecoverMiddleware. 
+// +// http.NewResponseController(w).Flush() also works for the same purpose +// because it walks Unwrap() — use that path when you can't rely on a +// type assertion. +func (t *responseWriterTap) Flush() { + if f, ok := t.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// Hijack forwards to the wrapped writer's Hijacker implementation if any. +// Required for any future websocket or connection-takeover use case; +// today no inspect or jsonrpc handler hijacks, but having the forwarding +// in place avoids a silent failure the moment one is added. Returns +// http.ErrNotSupported when the underlying writer does not implement +// http.Hijacker (for example, httptest.ResponseRecorder). +func (t *responseWriterTap) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := t.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, http.ErrNotSupported +} + +// RecoverMiddleware catches panics from downstream handlers. If nothing has +// been written yet it emits a generic 500 via WriteInternalError. If bytes +// are already on the wire it re-panics with http.ErrAbortHandler to drop the +// connection silently — stitching a 500 onto a partial 200 would produce a +// corrupt response and a "superfluous WriteHeader" warning. +// +// Panics whose value is http.ErrAbortHandler are re-panicked unchanged so +// that the stdlib server preserves its special (non-logged) semantics. +// +// When RecoverMiddleware wraps RequestIDMiddleware (Recover outermost), the +// deferred recovery cannot read the request id from r.Context() because r is +// the pre-wrap request from this closure — downstream middleware creates a +// new *http.Request via WithContext and passes it to its next.ServeHTTP, but +// the outer r is untouched. 
The request id is instead read from the tap's +// response header, which RequestIDMiddleware populates on the shared +// ResponseWriter before calling next, so it is always set by the time a +// downstream panic reaches the defer. +func RecoverMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tap := &responseWriterTap{ResponseWriter: w} + defer func() { + rec := recover() + if rec == nil { + return + } + if rec == http.ErrAbortHandler { + panic(rec) + } + reqID := tap.Header().Get(requestIDHeader) + logger.Error("http handler panic", + "panic", rec, + "stack", string(debug.Stack()), + "request_id", reqID, + ) + if !tap.wroteHeader { + ctx := context.WithValue(r.Context(), ctxKeyRequestID{}, reqID) + WriteInternalError(ctx, tap, logger, fmt.Errorf("panic in handler: %v", rec)) + return + } + panic(http.ErrAbortHandler) + }() + next.ServeHTTP(tap, r) + }) + } +} diff --git a/pkg/service/http_middleware_test.go b/pkg/service/http_middleware_test.go new file mode 100644 index 000000000..0225c2e1b --- /dev/null +++ b/pkg/service/http_middleware_test.go @@ -0,0 +1,396 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// runMiddleware wires a middleware in front of a handler and records the +// response via httptest.NewRecorder. 
+func runMiddleware(t *testing.T, mw func(http.Handler) http.Handler, h http.Handler, req *http.Request) *httptest.ResponseRecorder { + t.Helper() + rr := httptest.NewRecorder() + mw(h).ServeHTTP(rr, req) + return rr +} + +// ----------------------------------------------------------------------------- +// RequestIDMiddleware +// ----------------------------------------------------------------------------- + +func TestRequestIDMiddleware_GeneratesWhenMissing(t *testing.T) { + var capturedID string + h := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedID = RequestIDFromContext(r.Context()) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := runMiddleware(t, RequestIDMiddleware, h, req) + + require.NotEmpty(t, capturedID) + require.Equal(t, capturedID, rr.Header().Get("X-Request-ID")) + + _, err := uuid.Parse(capturedID) + require.NoError(t, err, "generated id should parse as UUID") +} + +func TestRequestIDMiddleware_AcceptsValid(t *testing.T) { + // Pin the full accepted charset. 
Each entry represents a real-world + // upstream format we must preserve end-to-end for correlation: + // - "abc_123-xyz" — legacy underscore/hyphen + // - "abc.def.123" — envoy-style dotted id + // - "a:b:c" — envoy host:port:id + // - "trace=1-2-3" — AWS X-Ray style with '=' + // - "projects/foo/traces/bar" — GCP Cloud Trace path + // - "trace+span" — base64-ish '+' + cases := []string{ + "abc_123-xyz", + "abc.def.123", + "a:b:c", + "trace=1-2-3", + "projects/foo/traces/bar", + "trace+span", + } + for _, valid := range cases { + t.Run(valid, func(t *testing.T) { + var capturedID string + h := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedID = RequestIDFromContext(r.Context()) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-ID", valid) + rr := runMiddleware(t, RequestIDMiddleware, h, req) + + require.Equal(t, valid, capturedID) + require.Equal(t, valid, rr.Header().Get("X-Request-ID")) + }) + } +} + +func TestRequestIDMiddleware_RejectsTooLong(t *testing.T) { + long := strings.Repeat("a", 129) + var capturedID string + h := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedID = RequestIDFromContext(r.Context()) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-ID", long) + runMiddleware(t, RequestIDMiddleware, h, req) + + require.NotEqual(t, long, capturedID) + _, err := uuid.Parse(capturedID) + require.NoError(t, err, "regenerated id should parse as UUID") +} + +func TestRequestIDMiddleware_RejectsBadChars(t *testing.T) { + // Each case must be regenerated as a fresh UUID. Keep the charset + // exclusion list tight: anything that could enable log-injection, + // header-splitting, or HTML/JS smuggling when echoed back in logs or + // on the X-Request-ID response header. 
+ cases := map[string]string{ + "semicolon": "foo;bar", + "space": "id with space", + "newline": "id\nnewline", + "carriage": "id\rcr", + "tab": "foo\tbar", + "angle": "