diff --git a/docs/http-posture.md b/docs/http-posture.md
new file mode 100644
index 000000000..ae9e601f5
--- /dev/null
+++ b/docs/http-posture.md
@@ -0,0 +1,364 @@
+# HTTP Deployment Posture
+
+**Audience:** operators running the Cartesi rollups-node in production.
+**Applies to:** releases that include the HTTP hardening package (2.0.0-alpha.12 and later).
+
+This document describes the HTTP-facing surfaces of the node, how they are
+protected in-process, and the deployment posture operators are expected to
+provide around them. The node assumes a **trusted network boundary** — the
+in-process controls are a defense-in-depth layer, not a substitute for
+operator-side network policy.
+
+## The three HTTP surfaces
+
+| Surface | Default address | Purpose | Per-request cost |
+| --- | --- | --- | --- |
+| **Telemetry** (`/livez`, `/readyz`) | `:10000` | Orchestrator health checks | Trivial — a boolean check and a short response |
+| **JSON-RPC API** (`/rpc`) | `:10011` | Read-only query interface | Up to 1 MiB body, DB queries, list responses up to 10000 items |
+| **Inspect** (`/inspect/{dapp}`) | `:10012` | Machine state query without advancing | Up to 2 MiB body, Cartesi Machine fork + execution |
+
+Telemetry is cheap by design — orchestrators (Kubernetes, Docker,
+systemd health checks) hammer it intentionally. JSON-RPC and inspect
+are the expensive surfaces and are the targets of the hardening package.
+
+## Bind defaults
+
+All three services bind to `:PORT` (all interfaces) by default. This is
+required for Docker container port publishing and for typical reverse-proxy
+front-ending patterns where the node listens on an in-container interface
+the proxy can reach.
+
+**On startup, the node logs a warning** for every HTTP service that binds
+to an unspecified address (`:PORT`, `0.0.0.0:PORT`, or `[::]:PORT`):
+
+```text
+WRN HTTP service bound to all interfaces; restrict access via firewall or reverse proxy
+    service=inspect addr=:10012
+```
+
+This warning is expected under Docker and Compose. In bare-metal
+deployments without a reverse proxy, consider overriding the addresses to
+loopback only:
+
+```bash
+CARTESI_INSPECT_ADDRESS=127.0.0.1:10012
+CARTESI_JSONRPC_API_ADDRESS=127.0.0.1:10011
+CARTESI_TELEMETRY_ADDRESS=127.0.0.1:10000
+```
+
+## Recommended deployment posture
+
+The node's HTTP surfaces are designed to sit **behind a reverse proxy**
+(nginx, Caddy, Traefik, Envoy, or a cloud API gateway). The proxy is
+expected to provide:
+
+- **TLS termination.** The node speaks plain HTTP only.
+- **Per-IP rate limiting.** The node has no IP-awareness; the proxy's
+  `limit_req`/`limit_conn`-equivalent primitives own that responsibility.
+- **Authentication and authorization.** The node's HTTP endpoints are
+  unauthenticated; do not expose them to untrusted clients without a
+  proxy that enforces auth.
+- **Connection and request caps at the network layer.** The in-process
+  admission control (described below) is a second line of defense, not
+  a replacement for proxy-level limits.
+
+For internal-only deployments without external exposure:
+
+- Bind to loopback or a private-network interface.
+- Use a firewall (`iptables`, `nftables`, cloud security groups) to
+  restrict source addresses.
+- Treat `inspect` as a sensitive execution surface — it can run machine
+  code on demand and should stay internal whenever possible.
+
+**Browser exposure.** CORS is **disabled by default** on both JSON-RPC
+and inspect. No `Access-Control-Allow-Origin` header is emitted unless
+the operator explicitly configures an origin allowlist via
+`CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS` or
+`CARTESI_INSPECT_CORS_ALLOWED_ORIGINS`.
+When configured, only
+exact-match origins are reflected — wildcard (`*`) is never used.
+
+For production browser exposure, prefer handling CORS at the reverse
+proxy (nginx, Caddy, Envoy) rather than in the node.
+
+Note that `http://localhost:3000` and `http://127.0.0.1:3000` are
+**distinct origins** in the browser's same-origin policy. If you bind
+the node to `127.0.0.1` but your frontend dev server runs at
+`http://localhost:3000`, you must allowlist the origin the browser
+sees (e.g. `http://localhost:3000`), not the bind address.
+
+CORS configuration is read at startup and is not reloadable. Changing
+allowed origins requires a process restart.
+
+## Timeout baselines
+
+Each HTTP surface uses a named preset from `pkg/service` (exported as
+`DefaultInspectOptions`, `DefaultTelemetryOptions`, `DefaultJSONRPCOptions`).
+All five fields are set on every server.
+
+| Field | Inspect | Telemetry | JSON-RPC |
+| --- | --- | --- | --- |
+| `ReadHeaderTimeout` | 10s | 5s | 10s |
+| `ReadTimeout` | 30s | 10s | 30s |
+| `WriteTimeout` | 600s | 10s | 30s |
+| `IdleTimeout` | 60s | 60s | 60s |
+| `MaxHeaderBytes` | 64 KiB | 16 KiB | 64 KiB |
+
+**Inspect `WriteTimeout` = 600s** is a process-health backstop. The actual
+per-request deadline is set structurally in `Inspector.ServeHTTP` as
+`InspectMaxDeadline + 30s` (the 30-second headroom covers response
+serialization). The HTTP `WriteTimeout` prevents leaked goroutines from
+holding a connection forever but never participates in normal request
+lifecycle.
+
+### Per-request deadline enforcement
+
+Each inspect request receives a `context.WithTimeout` set to the
+application's `InspectMaxDeadline + 30s` (the 30 seconds cover JSON
+response serialization and wire delivery). This deadline is set in
+`Inspector.ServeHTTP` after resolving the application, before invoking
+the Cartesi Machine.
+The machine manager sets a nested
+`context.WithTimeout(ctx, InspectMaxDeadline)` for the machine execution
+itself; Go context nesting means the shorter deadline always wins.
+
+This structural approach eliminates the need for coordinating
+`WriteTimeout` with `InspectMaxDeadline` — operators configure per-app
+deadlines without worrying about HTTP-layer timeouts.
+
+## Admission control
+
+Both expensive HTTP surfaces (JSON-RPC and inspect) have an **in-process
+concurrency gate** that fails fast when the number of in-flight requests
+exceeds a configured limit. This bounds goroutine count, per-request
+memory, and backend contention even under a flood.
+
+### Configuration
+
+| Env var | Default | Scope |
+| --- | --- | --- |
+| `CARTESI_INSPECT_MAX_INFLIGHT` | `64` | Inspect service only |
+| `CARTESI_JSONRPC_MAX_INFLIGHT` | `64` | JSON-RPC service only |
+
+Each surface has its **own independent budget**. A flood on one
+surface cannot starve the other. Telemetry is never gated — its
+per-request cost is too low to justify admission.
+
+A value of `0` disables admission on that surface. Backpressure then
+falls back to:
+
+- Inspect: the per-application Cartesi Machine semaphore (blocking, not
+  fail-fast; deeper in the request path).
+- JSON-RPC: the PostgreSQL connection pool (blocking).
+
+### Rejection semantics
+
+When admission rejects a request:
+
+```text
+HTTP/1.1 503 Service Unavailable
+Content-Type: text/plain; charset=utf-8
+X-Content-Type-Options: nosniff
+Retry-After: 1-3 (jittered)
+
+service at capacity (request_id=)
+```
+
+- **Silent by default** — the rejection is not logged per-request
+  (logging every rejection during a flood would amplify the flood).
+- The `Retry-After` header carries a jittered value in `[1, 3]` seconds
+  to desynchronize retrying clients and prevent thundering-herd pile-ups.
+  Clients and proxies may or may not honor it.
+- Operators observe saturation via the monotonic `Rejected()` counter
+  exposed on the `SemaphoreAdmission` instance. Wiring into a metrics
+  backend is out of scope for this release; see the rollups-node bug
+  taxonomy for when that becomes urgent.
+
+Note that the 503 response is `text/plain`, not a JSON-RPC error
+envelope. JSON-RPC 2.0 client SDKs will treat this as a transport-level
+error, not a protocol-level one. This is deliberate: admission rejection
+happens before the request reaches the JSON-RPC handler, so a transport
+error is the correct signal.
+
+### Tuning guidance
+
+The defaults (64 in-flight per surface) are conservative and fit a
+single-node deployment on typical hardware. Consider raising the limit
+when:
+
+- The reverse proxy is already doing per-IP rate limiting and the
+  admission gate is hitting the ceiling on legitimate traffic.
+- DB / machine capacity can sustain more concurrent work and the
+  `Rejected()` counter is non-trivial under normal load.
+
+Consider lowering when:
+
+- Memory pressure from buffered request bodies is observable.
+- Concurrent machine forks are slowing legitimate inspects to the point
+  of cascading timeouts.
+
+## Request bodies
+
+| Surface | Cap | Enforcement |
+| --- | --- | --- |
+| Telemetry | — | No request body |
+| JSON-RPC | **1 MiB** | `http.MaxBytesReader` in the handler |
+| Inspect | **2 MiB** | `http.MaxBytesReader` in the handler (matches the Cartesi Machine CMIO RX buffer) |
+
+Oversized bodies are rejected with:
+
+```text
+HTTP/1.1 413 Request Entity Too Large
+Content-Type: text/plain; charset=utf-8
+
+Payload too large
+```
+
+The connection is then force-closed by the stdlib, so clients cannot
+pipeline additional requests on the same connection. This behavior
+depends on the internal `responseWriterTap.Unwrap()` cooperating with
+`http.MaxBytesReader`; see the hardening v3 plan for the design note.
+
+**Worst-case body buffer memory under saturation.**
+Each admitted request pins its body buffer for the full request lifetime
+(up to `InspectMaxDeadline + 30s` for inspect (typically ~210s with the default 180s deadline), 30s for JSON-RPC). At default concurrency this
+means `CARTESI_INSPECT_MAX_INFLIGHT × 2 MiB = 128 MiB` for inspect and
+`CARTESI_JSONRPC_MAX_INFLIGHT × 1 MiB = 64 MiB` for JSON-RPC. Operators
+should size process RAM headroom accordingly, on top of machine state,
+database connections, and other working memory.
+
+## PostgreSQL pool sizing
+
+The node uses `pgxpool` for database access. The pool's default maximum
+connection count is `max(4, runtime.NumCPU())` unless overridden in the
+connection URL. On a typical 4-core container this means 4 connections
+shared across every service.
+
+**Rule of thumb:** set `pool_max_conns` to at least
+`CARTESI_JSONRPC_MAX_INFLIGHT + steady-state writer services`. In
+standalone mode the writer services (EVM Reader, Advancer, Validator,
+Claimer, PRT, plus overhead) account for roughly 6-8 connections, so a
+conservative floor is `64 + 8 = 72`. Override via the connection URL:
+
+```bash
+CARTESI_DATABASE_CONNECTION="postgres://user:pass@db:5432/rollups?pool_max_conns=72&pool_max_conn_lifetime=30m"
+```
+
+**Fail-fast vs. fail-slow.** Admission rejects are visible and
+immediate (503 with `Retry-After`). Exceeding the pool capacity is
+invisible and slow — handlers that passed admission block inside
+`pgxpool.Acquire` while holding admission permits, causing latency
+degradation with no obvious signal to the operator. Sizing the pool to
+match the admission limit prevents this silent backpressure. A future
+release may emit a startup warning when the sum of admission limits
+exceeds the configured pool size.
+
+Note that the inspect surface also queries the database (one
+`GetApplication` call per request) and shares the same underlying pool,
+so both hardened HTTP surfaces compete for connections.
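The sizing arithmetic in the two sections above is easy to re-run for non-default limits. The helpers below are a small sketch with hypothetical names (`worstCaseBodyBytes`, `poolFloor`); they only restate the formulas quoted in the text and do not read any node configuration.

```go
package main

import "fmt"

const mib = 1 << 20

// worstCaseBodyBytes is the body-buffer memory pinned when every admission
// slot holds a request with a maximum-size body.
func worstCaseBodyBytes(maxInflight, bodyCapBytes uint64) uint64 {
	return maxInflight * bodyCapBytes
}

// poolFloor applies the rule of thumb for pool_max_conns: the JSON-RPC
// admission limit plus the steady-state writer connections.
func poolFloor(jsonrpcMaxInflight, writerConns uint64) uint64 {
	return jsonrpcMaxInflight + writerConns
}

func main() {
	fmt.Printf("inspect bodies:  %d MiB\n", worstCaseBodyBytes(64, 2*mib)/mib)
	fmt.Printf("json-rpc bodies: %d MiB\n", worstCaseBodyBytes(64, 1*mib)/mib)
	fmt.Printf("pool_max_conns floor: %d\n", poolFloor(64, 8))
}
```

With the defaults this reproduces the 128 MiB, 64 MiB, and 72 figures from the text. The floor does not account for the `GetApplication` query that each inspect request issues against the same pool, so heavy inspect traffic may justify extra headroom.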
+
+## Request IDs
+
+Every response from inspect, JSON-RPC, and telemetry includes an
+`X-Request-ID` header. The middleware enforces:
+
+- **Validation:** upstream `X-Request-ID` is trusted only if it matches
+  `^[A-Za-z0-9._:=/+-]{1,128}$`. Values outside that charset or longer
+  than 128 characters are **discarded** and a fresh UUIDv4 is generated
+  in their place. This prevents log injection via `\n` / `\r` and caps
+  header cardinality.
+- **Generation:** when no upstream ID is present (or the upstream value
+  was rejected), the node generates a UUIDv4 via `github.com/google/uuid`.
+- **Propagation:** the chosen ID is placed on the request context and
+  echoed on the response `X-Request-ID` header. Error log lines from the
+  handler include the ID as a structured field.
+
+This lets operators correlate a single request across:
+
+- The reverse proxy's access log (if it assigns IDs upstream).
+- The rollups-node's structured log output.
+- The client's error response (see next section).
+
+## Internal errors
+
+When a handler hits an unexpected error that it can't express as a
+domain-level status code, it responds with:
+
+```text
+HTTP/1.1 500 Internal Server Error
+Content-Type: text/plain; charset=utf-8
+
+Internal server error (request_id=)
+```
+
+The original Go `error` value is **never** written to the response body.
+Its full content — message, wrapped chain, stack if available — is
+logged at ERR level with the request ID as a structured field. Operators
+triage a user's 500 report by grepping logs for the request ID they
+reported.
+
+As with admission 503s, the 500 response is `text/plain`, not a JSON-RPC
+error envelope. JSON-RPC client SDKs will surface these as transport
+errors, which is the correct signal for a server-side fault.
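The request-ID rules from the "Request IDs" section (validate, generate, propagate) can be sketched as a small middleware. This is an illustration under assumptions, not the node's middleware: `withRequestID`, `chooseRequestID`, `echoedID`, and `ctxKey` are hypothetical names, and `newUUIDv4` is a standard-library stand-in for `github.com/google/uuid` so the sketch has no external dependency.

```go
package main

import (
	"context"
	"crypto/rand"
	"fmt"
	"net/http"
	"net/http/httptest"
	"regexp"
)

// validRequestID is the charset and length rule quoted in the Request IDs section.
var validRequestID = regexp.MustCompile(`^[A-Za-z0-9._:=/+-]{1,128}$`)

// ctxKey is a private context key type for the request ID (hypothetical).
type ctxKey struct{}

// newUUIDv4 builds a random (version 4) UUID from crypto/rand.
// It stands in for github.com/google/uuid in this sketch only.
func newUUIDv4() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		panic(err)
	}
	b[6] = (b[6] & 0x0f) | 0x40 // version 4
	b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}

// chooseRequestID keeps a well-formed upstream ID and replaces anything else.
func chooseRequestID(upstream string) string {
	if validRequestID.MatchString(upstream) {
		return upstream
	}
	return newUUIDv4()
}

// withRequestID echoes the chosen ID on the response and stores it on the context.
func withRequestID(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := chooseRequestID(r.Header.Get("X-Request-ID"))
		w.Header().Set("X-Request-ID", id)
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxKey{}, id)))
	})
}

// echoedID runs one request through the middleware and returns the response header.
func echoedID(upstream string) string {
	req := httptest.NewRequest(http.MethodGet, "/livez", nil)
	if upstream != "" {
		req.Header.Set("X-Request-ID", upstream)
	}
	rec := httptest.NewRecorder()
	withRequestID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "handled %v", r.Context().Value(ctxKey{}))
	})).ServeHTTP(rec, req)
	return rec.Header().Get("X-Request-ID")
}

func main() {
	fmt.Println(echoedID("proxy-req.42"))    // well-formed upstream ID is kept
	fmt.Println(len(echoedID("bad id\r\n"))) // rejected value is replaced by a generated UUID
	fmt.Println(len(echoedID("")))           // missing ID is generated
}
```

Go's `regexp` treats `$` as end-of-text unless multi-line mode is enabled, so a value with a trailing newline does not slip through the pattern.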
+
+## Panic recovery
+
+Every HTTP handler chain is wrapped in a panic-recovery middleware:
+
+- **If the handler panics before any byte has been written**, the
+  middleware catches the panic, logs the value and stack trace at ERR
+  level with the request ID, and writes a generic 500 (same format as
+  "Internal errors" above).
+- **If the handler panics after bytes have been flushed**, the
+  middleware cannot safely write a 500 without producing a corrupt
+  response (stitching a `500 Internal Server Error` onto a started 200
+  would lie to the client and trigger Go's "superfluous WriteHeader"
+  warning). Instead, the middleware re-panics with `http.ErrAbortHandler`
+  — the stdlib's documented sentinel for "abort this connection
+  silently". The client observes a truncated response and connection
+  drop, which is the honest signal.
+- **Panics whose value is already `http.ErrAbortHandler` are re-panicked
+  unchanged.** This preserves the stdlib contract for handlers that
+  intentionally use the sentinel to abort without logging.
+
+## Known limitations
+
+1. **Two-layer admission on inspect.** Inspect requests pass through two
+   independent concurrency gates: the HTTP-global admission gate (default
+   64 in-flight, configured via `CARTESI_INSPECT_MAX_INFLIGHT`) and a
+   per-application machine semaphore (`MaxConcurrentInspects`, default
+   10). Both gates are fail-fast (`TryAcquire`): when either is full the
+   request is rejected immediately with 503 and the caller's admission
+   permit is released. This means one saturated application does **not**
+   starve others — its excess requests fail at the per-app gate and free
+   HTTP-global capacity for other apps. Operators should be aware that
+   both layers return 503; the HTTP-global gate includes a `Retry-After`
+   header while the per-app gate does not.
+
+2. **Single-replica assumption.** The admission budget (default 64 per
+   surface) is per-process.
+   Multi-replica deployments behind a load
+   balancer have an effective budget of `replicas × CARTESI_*_MAX_INFLIGHT`.
+   The default 64 is sized for a single-node deployment on typical
+   hardware; operators should not assume it represents a global limit.
+
+## Non-goals
+
+The following are explicitly **out of scope** for the HTTP hardening
+package. If you need them, add them at the reverse proxy or via
+follow-up work.
+
+- TLS termination inside the node.
+- Authentication / authorization on HTTP endpoints.
+- Per-IP rate limiting.
+- Per-application fairness inside the admission gate (the gate is
+  global per HTTP surface).
+- Global cross-service admission (inspect and JSON-RPC have independent
+  budgets).
+- Admission on telemetry.
+- Wiring `SemaphoreAdmission.Rejected()` into a metrics backend.
+- Flipping bind defaults to loopback.
+- Exposing `net/http/pprof`.
+
+## Related documentation
+
+- `docs/config.md` — generated reference for every `CARTESI_*`
+  environment variable, including `CARTESI_INSPECT_MAX_INFLIGHT` and
+  `CARTESI_JSONRPC_MAX_INFLIGHT`.
diff --git a/go.mod b/go.mod index 4948c424f..643857e04 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/deepmap/oapi-codegen/v2 v2.2.0 github.com/go-jet/jet/v2 v2.14.1 github.com/golang-migrate/migrate/v4 v4.19.1 + github.com/google/uuid v1.6.0 github.com/hashicorp/go-retryablehttp v0.7.8 github.com/jackc/pgx/v5 v5.8.0 github.com/lmittmann/tint v1.1.3 @@ -66,7 +67,6 @@ require ( github.com/go-openapi/swag v0.23.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/holiman/uint256 v1.3.2 // indirect diff --git a/internal/advancer/service.go b/internal/advancer/service.go index 8db09aa5f..3c6d4d7c2 100644 --- a/internal/advancer/service.go +++ b/internal/advancer/service.go @@ -20,7 +20,7 @@ import ( // httpShutdownTimeout is how long to wait for in-flight inspect HTTP requests // to drain before forcibly closing the server during shutdown. 
-const httpShutdownTimeout = 10 * time.Second //nolint: mnd +const httpShutdownTimeout = 10 * time.Second // Service is the main advancer service that processes inputs through Cartesi machines type Service struct { @@ -30,8 +30,6 @@ type Service struct { repository AdvancerRepository machineManager manager.MachineProvider inspector *inspect.Inspector - HTTPServer *http.Server - HTTPServerFunc func() error // cleanedUp ensures HTTP server shutdown and machine manager close run // exactly once, even when Stop() is called multiple times (by the child's @@ -83,13 +81,23 @@ func Create(ctx context.Context, c *CreateInfo) (*Service, error) { // Initialize the inspect service if enabled if c.Config.FeatureInspectEnabled { - s.inspector, s.HTTPServer, s.HTTPServerFunc = inspect.NewInspector( - c.Repository, - manager, - c.Config.InspectAddress, - c.LogLevel, - c.LogColor, - ) + var admission *service.SemaphoreAdmission + if c.Config.InspectMaxInflight > 0 { + admission = service.NewSemaphoreAdmission(c.Config.InspectMaxInflight) + } + inspector, err := inspect.NewInspector(inspect.CreateInfo{ + Repository: c.Repository, + Machines: manager, + Address: c.Config.InspectAddress, + LogLevel: c.LogLevel, + LogPretty: c.LogColor, + Admission: admission, + CORSAllowedOrigins: c.Config.InspectCorsAllowedOrigins, + }) + if err != nil { + return nil, fmt.Errorf("failed to create inspect service: %w", err) + } + s.inspector = inspector } s.snapshotsDir = c.Config.SnapshotsDir @@ -137,11 +145,11 @@ func (s *Service) Stop(b bool) []error { // resources would not see IsStopping() == true. 
s.SetStopping() var errs []error - if s.HTTPServer != nil { + if s.inspector != nil { s.Logger.Info("Shutting down inspect HTTP server") shutdownCtx, cancel := context.WithTimeout(context.Background(), httpShutdownTimeout) defer cancel() - if err := s.HTTPServer.Shutdown(shutdownCtx); err != nil { + if err := s.inspector.Shutdown(shutdownCtx); err != nil { errs = append(errs, fmt.Errorf("failed to shutdown inspect HTTP server: %w", err)) } } @@ -154,9 +162,9 @@ func (s *Service) Stop(b bool) []error { return errs } func (s *Service) Serve() error { - if s.inspector != nil && s.HTTPServerFunc != nil { + if s.inspector != nil { go func() { - if err := s.HTTPServerFunc(); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := s.inspector.Serve(); err != nil && !errors.Is(err, http.ErrServerClosed) { s.Logger.Error("Inspect HTTP server failed — shutting down", "error", err) s.Cancel() } diff --git a/internal/config/generate/Config.toml b/internal/config/generate/Config.toml index 0fcc1fd8a..442bdfb10 100644 --- a/internal/config/generate/Config.toml +++ b/internal/config/generate/Config.toml @@ -445,6 +445,34 @@ description = """ HTTP address for the JSON-RPC API.""" used-by = ["jsonrpc", "node"] +[http.CARTESI_JSONRPC_MAX_INFLIGHT] +default = "64" +go-type = "uint64" +description = """ +Maximum number of concurrent in-flight JSON-RPC requests. +Requests beyond this limit receive HTTP 503 Service Unavailable +with Retry-After: 1. Set to 0 to disable HTTP-level admission +control.""" +used-by = ["jsonrpc", "node"] + +[http.CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS] +default = "" +go-type = "string" +description = """ +Comma-separated list of allowed browser origins for the JSON-RPC API. +If empty, CORS is disabled. Origins are lowercased and validated at +startup. 
Example: "http://localhost:3000,https://app.example.com".""" +used-by = ["jsonrpc", "node"] + +[http.CARTESI_INSPECT_CORS_ALLOWED_ORIGINS] +default = "" +go-type = "string" +description = """ +Comma-separated list of allowed browser origins for inspect. +If empty, CORS is disabled. Origins are lowercased and validated at +startup. Example: "http://localhost:3000,https://app.example.com".""" +used-by = ["advancer", "node"] + [http.CARTESI_INSPECT_ADDRESS] default = ":10012" go-type = "string" @@ -452,6 +480,17 @@ description = """ HTTP address for inspect.""" used-by = ["advancer", "node"] +[http.CARTESI_INSPECT_MAX_INFLIGHT] +default = "64" +go-type = "uint64" +description = """ +Maximum number of concurrent in-flight HTTP inspect requests. +Requests beyond this limit receive HTTP 503 Service Unavailable +with Retry-After: 1. Set to 0 to disable HTTP-level admission +control (backpressure then falls back to the per-application +machine semaphore).""" +used-by = ["advancer", "node"] + [http.CARTESI_JSONRPC_API_URL] default = "http://localhost:10011/rpc" go-type = "string" diff --git a/internal/config/generate/code.go b/internal/config/generate/code.go index f32591338..f8f8ed97c 100644 --- a/internal/config/generate/code.go +++ b/internal/config/generate/code.go @@ -59,6 +59,11 @@ var funcMap = template.FuncMap{ "mapstructure": func(s string) string { return "`mapstructure:\"" + s + "\"`" }, + // hasEmptyStringDefault checks if a variable has default="" and is a string type. + // These vars need viper.IsSet instead of s!="" because empty string is a valid value. 
+ "hasEmptyStringDefault": func(env Env) bool { + return env.Default != nil && *env.Default == "" && env.GoType == "string" + }, // isUsedBy checks if a variable is used by a specific service "isUsedBy": func(env Env, service string) bool { return slices.Contains(env.UsedBy, service) @@ -223,7 +228,11 @@ func Get{{ toFieldName .Name }}() ({{ .GoType }}, error) { s = strings.TrimSpace(string(contents)) } {{- end }} + {{- if hasEmptyStringDefault . }} + if viper.IsSet({{ toConstName .Name }}) { + {{- else }} if s != "" { + {{- end }} v, err := {{ toGoFunc .GoType }}(s) if err != nil { return v, fmt.Errorf("failed to parse %s: %w", {{ toConstName .Name }}, err) diff --git a/internal/config/generated.go b/internal/config/generated.go index 8e06ecd05..defadaaed 100644 --- a/internal/config/generated.go +++ b/internal/config/generated.go @@ -49,9 +49,13 @@ const ( CLAIMER_TELEMETRY_ADDRESS = "CARTESI_CLAIMER_TELEMETRY_ADDRESS" EVM_READER_TELEMETRY_ADDRESS = "CARTESI_EVM_READER_TELEMETRY_ADDRESS" INSPECT_ADDRESS = "CARTESI_INSPECT_ADDRESS" + INSPECT_CORS_ALLOWED_ORIGINS = "CARTESI_INSPECT_CORS_ALLOWED_ORIGINS" + INSPECT_MAX_INFLIGHT = "CARTESI_INSPECT_MAX_INFLIGHT" INSPECT_URL = "CARTESI_INSPECT_URL" JSONRPC_API_ADDRESS = "CARTESI_JSONRPC_API_ADDRESS" JSONRPC_API_URL = "CARTESI_JSONRPC_API_URL" + JSONRPC_CORS_ALLOWED_ORIGINS = "CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS" + JSONRPC_MAX_INFLIGHT = "CARTESI_JSONRPC_MAX_INFLIGHT" JSONRPC_TELEMETRY_ADDRESS = "CARTESI_JSONRPC_TELEMETRY_ADDRESS" NODE_TELEMETRY_ADDRESS = "CARTESI_NODE_TELEMETRY_ADDRESS" PRT_TELEMETRY_ADDRESS = "CARTESI_PRT_TELEMETRY_ADDRESS" @@ -151,12 +155,20 @@ func SetDefaults() { viper.SetDefault(INSPECT_ADDRESS, ":10012") + viper.SetDefault(INSPECT_CORS_ALLOWED_ORIGINS, "") + + viper.SetDefault(INSPECT_MAX_INFLIGHT, "64") + viper.SetDefault(INSPECT_URL, "http://localhost:10012") viper.SetDefault(JSONRPC_API_ADDRESS, ":10011") viper.SetDefault(JSONRPC_API_URL, "http://localhost:10011/rpc") + 
viper.SetDefault(JSONRPC_CORS_ALLOWED_ORIGINS, "") + + viper.SetDefault(JSONRPC_MAX_INFLIGHT, "64") + viper.SetDefault(JSONRPC_TELEMETRY_ADDRESS, ":10005") viper.SetDefault(NODE_TELEMETRY_ADDRESS, ":10000") @@ -239,6 +251,18 @@ type AdvancerConfig struct { // HTTP address for inspect. InspectAddress string `mapstructure:"CARTESI_INSPECT_ADDRESS"` + // Comma-separated list of allowed browser origins for inspect. + // If empty, CORS is disabled. Origins are lowercased and validated at + // startup. Example: "http://localhost:3000,https://app.example.com". + InspectCorsAllowedOrigins string `mapstructure:"CARTESI_INSPECT_CORS_ALLOWED_ORIGINS"` + + // Maximum number of concurrent in-flight HTTP inspect requests. + // Requests beyond this limit receive HTTP 503 Service Unavailable + // with Retry-After: 1. Set to 0 to disable HTTP-level admission + // control (backpressure then falls back to the per-application + // machine semaphore). + InspectMaxInflight uint64 `mapstructure:"CARTESI_INSPECT_MAX_INFLIGHT"` + // If set to true, the node will add colors to its log output. 
LogColor bool `mapstructure:"CARTESI_LOG_COLOR"` @@ -312,6 +336,20 @@ func LoadAdvancerConfig() (*AdvancerConfig, error) { return nil, fmt.Errorf("CARTESI_INSPECT_ADDRESS is required for the advancer service: %w", err) } + cfg.InspectCorsAllowedOrigins, err = GetInspectCorsAllowedOrigins() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_INSPECT_CORS_ALLOWED_ORIGINS: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_INSPECT_CORS_ALLOWED_ORIGINS is required for the advancer service: %w", err) + } + + cfg.InspectMaxInflight, err = GetInspectMaxInflight() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_INSPECT_MAX_INFLIGHT: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_INSPECT_MAX_INFLIGHT is required for the advancer service: %w", err) + } + cfg.LogColor, err = GetLogColor() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_LOG_COLOR: %w", err) @@ -759,6 +797,17 @@ type JsonrpcConfig struct { // HTTP address for the JSON-RPC API. JsonrpcApiAddress string `mapstructure:"CARTESI_JSONRPC_API_ADDRESS"` + // Comma-separated list of allowed browser origins for the JSON-RPC API. + // If empty, CORS is disabled. Origins are lowercased and validated at + // startup. Example: "http://localhost:3000,https://app.example.com". + JsonrpcCorsAllowedOrigins string `mapstructure:"CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS"` + + // Maximum number of concurrent in-flight JSON-RPC requests. + // Requests beyond this limit receive HTTP 503 Service Unavailable + // with Retry-After: 1. Set to 0 to disable HTTP-level admission + // control. + JsonrpcMaxInflight uint64 `mapstructure:"CARTESI_JSONRPC_MAX_INFLIGHT"` + // HTTP address for JSON-RPC's telemetry service. 
JsonrpcTelemetryAddress string `mapstructure:"CARTESI_JSONRPC_TELEMETRY_ADDRESS"` @@ -800,6 +849,20 @@ func LoadJsonrpcConfig() (*JsonrpcConfig, error) { return nil, fmt.Errorf("CARTESI_JSONRPC_API_ADDRESS is required for the jsonrpc service: %w", err) } + cfg.JsonrpcCorsAllowedOrigins, err = GetJsonrpcCorsAllowedOrigins() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS is required for the jsonrpc service: %w", err) + } + + cfg.JsonrpcMaxInflight, err = GetJsonrpcMaxInflight() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_JSONRPC_MAX_INFLIGHT: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_JSONRPC_MAX_INFLIGHT is required for the jsonrpc service: %w", err) + } + cfg.JsonrpcTelemetryAddress, err = GetJsonrpcTelemetryAddress() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_JSONRPC_TELEMETRY_ADDRESS: %w", err) @@ -880,9 +943,32 @@ type NodeConfig struct { // HTTP address for inspect. InspectAddress string `mapstructure:"CARTESI_INSPECT_ADDRESS"` + // Comma-separated list of allowed browser origins for inspect. + // If empty, CORS is disabled. Origins are lowercased and validated at + // startup. Example: "http://localhost:3000,https://app.example.com". + InspectCorsAllowedOrigins string `mapstructure:"CARTESI_INSPECT_CORS_ALLOWED_ORIGINS"` + + // Maximum number of concurrent in-flight HTTP inspect requests. + // Requests beyond this limit receive HTTP 503 Service Unavailable + // with Retry-After: 1. Set to 0 to disable HTTP-level admission + // control (backpressure then falls back to the per-application + // machine semaphore). + InspectMaxInflight uint64 `mapstructure:"CARTESI_INSPECT_MAX_INFLIGHT"` + // HTTP address for the JSON-RPC API. 
JsonrpcApiAddress string `mapstructure:"CARTESI_JSONRPC_API_ADDRESS"` + // Comma-separated list of allowed browser origins for the JSON-RPC API. + // If empty, CORS is disabled. Origins are lowercased and validated at + // startup. Example: "http://localhost:3000,https://app.example.com". + JsonrpcCorsAllowedOrigins string `mapstructure:"CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS"` + + // Maximum number of concurrent in-flight JSON-RPC requests. + // Requests beyond this limit receive HTTP 503 Service Unavailable + // with Retry-After: 1. Set to 0 to disable HTTP-level admission + // control. + JsonrpcMaxInflight uint64 `mapstructure:"CARTESI_JSONRPC_MAX_INFLIGHT"` + // HTTP address for Node's telemetry service. NodeTelemetryAddress string `mapstructure:"CARTESI_NODE_TELEMETRY_ADDRESS"` @@ -1038,6 +1124,20 @@ func LoadNodeConfig() (*NodeConfig, error) { return nil, fmt.Errorf("CARTESI_INSPECT_ADDRESS is required for the node service: %w", err) } + cfg.InspectCorsAllowedOrigins, err = GetInspectCorsAllowedOrigins() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_INSPECT_CORS_ALLOWED_ORIGINS: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_INSPECT_CORS_ALLOWED_ORIGINS is required for the node service: %w", err) + } + + cfg.InspectMaxInflight, err = GetInspectMaxInflight() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_INSPECT_MAX_INFLIGHT: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_INSPECT_MAX_INFLIGHT is required for the node service: %w", err) + } + cfg.JsonrpcApiAddress, err = GetJsonrpcApiAddress() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_JSONRPC_API_ADDRESS: %w", err) @@ -1045,6 +1145,20 @@ func LoadNodeConfig() (*NodeConfig, error) { return nil, fmt.Errorf("CARTESI_JSONRPC_API_ADDRESS is required for the node service: %w", err) } + cfg.JsonrpcCorsAllowedOrigins, err = 
GetJsonrpcCorsAllowedOrigins() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS is required for the node service: %w", err) + } + + cfg.JsonrpcMaxInflight, err = GetJsonrpcMaxInflight() + if err != nil && err != ErrNotDefined { + return nil, fmt.Errorf("failed to get CARTESI_JSONRPC_MAX_INFLIGHT: %w", err) + } else if err == ErrNotDefined { + return nil, fmt.Errorf("CARTESI_JSONRPC_MAX_INFLIGHT is required for the node service: %w", err) + } + cfg.NodeTelemetryAddress, err = GetNodeTelemetryAddress() if err != nil && err != ErrNotDefined { return nil, fmt.Errorf("failed to get CARTESI_NODE_TELEMETRY_ADDRESS: %w", err) @@ -1459,6 +1573,8 @@ func (c *NodeConfig) ToAdvancerConfig() *AdvancerConfig { FeatureInspectEnabled: c.FeatureInspectEnabled, FeatureMachineHashCheckEnabled: c.FeatureMachineHashCheckEnabled, InspectAddress: c.InspectAddress, + InspectCorsAllowedOrigins: c.InspectCorsAllowedOrigins, + InspectMaxInflight: c.InspectMaxInflight, LogColor: c.LogColor, LogLevel: c.LogLevel, JsonrpcMachineLogLevel: c.JsonrpcMachineLogLevel, @@ -1514,11 +1630,13 @@ func (c *NodeConfig) ToEvmreaderConfig() *EvmreaderConfig { // ToJsonrpcConfig converts a NodeConfig to a JsonrpcConfig. 
func (c *NodeConfig) ToJsonrpcConfig() *JsonrpcConfig { return &JsonrpcConfig{ - DatabaseConnection: c.DatabaseConnection, - JsonrpcApiAddress: c.JsonrpcApiAddress, - LogColor: c.LogColor, - LogLevel: c.LogLevel, - MaxStartupTime: c.MaxStartupTime, + DatabaseConnection: c.DatabaseConnection, + JsonrpcApiAddress: c.JsonrpcApiAddress, + JsonrpcCorsAllowedOrigins: c.JsonrpcCorsAllowedOrigins, + JsonrpcMaxInflight: c.JsonrpcMaxInflight, + LogColor: c.LogColor, + LogLevel: c.LogLevel, + MaxStartupTime: c.MaxStartupTime, } } @@ -1953,6 +2071,32 @@ func GetInspectAddress() (string, error) { return notDefinedstring(), fmt.Errorf("%s: %w", INSPECT_ADDRESS, ErrNotDefined) } +// GetInspectCorsAllowedOrigins returns the value for the environment variable CARTESI_INSPECT_CORS_ALLOWED_ORIGINS. +func GetInspectCorsAllowedOrigins() (string, error) { + s := viper.GetString(INSPECT_CORS_ALLOWED_ORIGINS) + if viper.IsSet(INSPECT_CORS_ALLOWED_ORIGINS) { + v, err := toString(s) + if err != nil { + return v, fmt.Errorf("failed to parse %s: %w", INSPECT_CORS_ALLOWED_ORIGINS, err) + } + return v, nil + } + return notDefinedstring(), fmt.Errorf("%s: %w", INSPECT_CORS_ALLOWED_ORIGINS, ErrNotDefined) +} + +// GetInspectMaxInflight returns the value for the environment variable CARTESI_INSPECT_MAX_INFLIGHT. +func GetInspectMaxInflight() (uint64, error) { + s := viper.GetString(INSPECT_MAX_INFLIGHT) + if s != "" { + v, err := toUint64(s) + if err != nil { + return v, fmt.Errorf("failed to parse %s: %w", INSPECT_MAX_INFLIGHT, err) + } + return v, nil + } + return notDefineduint64(), fmt.Errorf("%s: %w", INSPECT_MAX_INFLIGHT, ErrNotDefined) +} + // GetInspectUrl returns the value for the environment variable CARTESI_INSPECT_URL. 
func GetInspectUrl() (string, error) { s := viper.GetString(INSPECT_URL) @@ -1992,6 +2136,32 @@ func GetJsonrpcApiUrl() (string, error) { return notDefinedstring(), fmt.Errorf("%s: %w", JSONRPC_API_URL, ErrNotDefined) } +// GetJsonrpcCorsAllowedOrigins returns the value for the environment variable CARTESI_JSONRPC_CORS_ALLOWED_ORIGINS. +func GetJsonrpcCorsAllowedOrigins() (string, error) { + s := viper.GetString(JSONRPC_CORS_ALLOWED_ORIGINS) + if viper.IsSet(JSONRPC_CORS_ALLOWED_ORIGINS) { + v, err := toString(s) + if err != nil { + return v, fmt.Errorf("failed to parse %s: %w", JSONRPC_CORS_ALLOWED_ORIGINS, err) + } + return v, nil + } + return notDefinedstring(), fmt.Errorf("%s: %w", JSONRPC_CORS_ALLOWED_ORIGINS, ErrNotDefined) +} + +// GetJsonrpcMaxInflight returns the value for the environment variable CARTESI_JSONRPC_MAX_INFLIGHT. +func GetJsonrpcMaxInflight() (uint64, error) { + s := viper.GetString(JSONRPC_MAX_INFLIGHT) + if s != "" { + v, err := toUint64(s) + if err != nil { + return v, fmt.Errorf("failed to parse %s: %w", JSONRPC_MAX_INFLIGHT, err) + } + return v, nil + } + return notDefineduint64(), fmt.Errorf("%s: %w", JSONRPC_MAX_INFLIGHT, ErrNotDefined) +} + // GetJsonrpcTelemetryAddress returns the value for the environment variable CARTESI_JSONRPC_TELEMETRY_ADDRESS. func GetJsonrpcTelemetryAddress() (string, error) { s := viper.GetString(JSONRPC_TELEMETRY_ADDRESS) diff --git a/internal/inspect/hardening_test.go b/internal/inspect/hardening_test.go new file mode 100644 index 000000000..63e107465 --- /dev/null +++ b/internal/inspect/hardening_test.go @@ -0,0 +1,548 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package inspect + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/cartesi/rollups-node/internal/manager" + . 
"github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/pkg/service" + + "github.com/stretchr/testify/require" +) + +// newInspectorForTest constructs an Inspector via NewInspector with a mock +// repository and machine set, but returns only the http.Handler — the real +// http.Server is never started. Tests exercise it with httptest. +// +// When machineErr is non-nil, the mock machine's Inspect call returns it, +// so tests can cover the internal-error path. +func newInspectorForTest(t *testing.T, machineErr error) (*Inspector, *Application) { + t.Helper() + + app := &Application{ + ID: 1, + IApplicationAddress: randomAddress(), + Name: "test-app", + ExecutionParameters: ExecutionParameters{ + InspectMaxDeadline: 10 * time.Second, + }, + } + repo := newMockRepository() + repo.apps = append(repo.apps, app) + + mm := newMockMachines() + mm.Map[1] = MockMachine{application: app} + + insp, err := NewInspector(CreateInfo{ + Repository: repo, + Machines: hardeningMachines{MachinesMock: mm, err: machineErr}, + Address: "127.0.0.1:0", + LogLevel: slog.LevelError, + LogPretty: false, + }) + require.NoError(t, err) + return insp, app +} + +// hardeningMachines wraps MachinesMock so an injected error propagates out +// of the MockMachine.Inspect call used by tests. The underlying MockMachine +// never returns an error on its own. 
+type hardeningMachines struct { + *MachinesMock + err error +} + +func (h hardeningMachines) GetMachine(appID int64) (manager.MachineInstance, bool) { + inst, ok := h.MachinesMock.GetMachine(appID) + if !ok { + return nil, false + } + if h.err != nil { + return &erroringMachine{inner: inst, err: h.err}, true + } + return inst, true +} + +type erroringMachine struct { + inner manager.MachineInstance + err error +} + +func (m *erroringMachine) Inspect(_ context.Context, _ []byte) (*InspectResult, error) { + if m.err == errPanicSentinel { + panic("boom-from-machine") + } + return nil, m.err +} + +// errPanicSentinel is a marker value understood by erroringMachine.Inspect: +// when passed as the machineErr to newInspectorForTest, the mock machine +// panics instead of returning an error, so tests can exercise the +// RecoverMiddleware branch without monkey-patching the inspect handler. +var errPanicSentinel = errors.New("sentinel: make machine panic") + +// Forward the rest to the inner machine so the type satisfies the interface +// without reimplementing the stubs. Tests only reach Inspect. 
+func (m *erroringMachine) Advance(ctx context.Context, input []byte, a, b uint64, c bool) (*AdvanceResult, error) { + return m.inner.Advance(ctx, input, a, b, c) +} +func (m *erroringMachine) Application() *Application { return m.inner.Application() } +func (m *erroringMachine) ProcessedInputs() uint64 { return m.inner.ProcessedInputs() } +func (m *erroringMachine) OutputsProof(ctx context.Context) (*OutputsProof, error) { + return m.inner.OutputsProof(ctx) +} +func (m *erroringMachine) Synchronize(ctx context.Context, repo manager.MachineRepository, batchSize uint64) error { + return m.inner.Synchronize(ctx, repo, batchSize) +} +func (m *erroringMachine) CreateSnapshot(ctx context.Context, processedInputs uint64, path string) error { + return m.inner.CreateSnapshot(ctx, processedInputs, path) +} +func (m *erroringMachine) Hash(ctx context.Context) ([32]byte, error) { return m.inner.Hash(ctx) } +func (m *erroringMachine) Close() error { return m.inner.Close() } + +// ----------------------------------------------------------------------------- + +func TestInspector_NewWithCreateInfo(t *testing.T) { + insp, _ := newInspectorForTest(t, nil) + // Package-internal access: the hardened http.Server is unexported and + // tests pin its fields directly rather than via a public accessor. 
+ srv := insp.server + require.NotNil(t, srv) + + opts := service.DefaultInspectOptions() + require.Equal(t, opts.ReadHeaderTimeout, srv.ReadHeaderTimeout) + require.Equal(t, opts.ReadTimeout, srv.ReadTimeout) + require.Equal(t, opts.WriteTimeout, srv.WriteTimeout) + require.Equal(t, opts.IdleTimeout, srv.IdleTimeout) + require.Equal(t, opts.MaxHeaderBytes, srv.MaxHeaderBytes) +} + +func TestInspector_NewRejectsNilMachines(t *testing.T) { + _, err := NewInspector(CreateInfo{ + Repository: newMockRepository(), + Machines: nil, + Address: "127.0.0.1:0", + }) + require.ErrorIs(t, err, ErrInvalidMachines) +} + +func TestInspector_OversizedPayloadReturns413(t *testing.T) { + insp, app := newInspectorForTest(t, nil) + body := bytes.NewReader(make([]byte, maxPayloadSize+1)) + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), body) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusRequestEntityTooLarge, rr.Code) + require.Contains(t, rr.Body.String(), "Payload too large") +} + +func TestInspector_ExactBoundaryAccepted(t *testing.T) { + insp, app := newInspectorForTest(t, nil) + body := bytes.NewReader(make([]byte, maxPayloadSize)) + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), body) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "body at exact limit must be accepted") +} + +func TestInspector_GETReturns405WithAllowHeader(t *testing.T) { + insp, app := newInspectorForTest(t, nil) + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/inspect/%s", app.Name), nil) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusMethodNotAllowed, rr.Code) + require.Equal(t, http.MethodPost, rr.Header().Get("Allow")) +} + +func TestInspector_PUTReturns405WithAllowHeader(t *testing.T) { + insp, app := newInspectorForTest(t, nil) + req := 
httptest.NewRequest(http.MethodPut, fmt.Sprintf("/inspect/%s", app.Name), nil) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusMethodNotAllowed, rr.Code) + require.Equal(t, http.MethodPost, rr.Header().Get("Allow")) +} + +func TestInspector_InternalErrorBodyIsGeneric(t *testing.T) { + secret := errors.New("secret-credentials-xyzzy") + insp, app := newInspectorForTest(t, secret) + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusInternalServerError, rr.Code) + require.Contains(t, rr.Body.String(), "Internal server error (request_id=") + require.NotContains(t, rr.Body.String(), "secret-credentials-xyzzy", + "internal error body must never leak err.Error()") +} + +func TestInspector_InternalErrorIncludesRequestID(t *testing.T) { + insp, app := newInspectorForTest(t, errors.New("boom")) + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + req.Header.Set("X-Request-ID", "pinned-id-42") + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusInternalServerError, rr.Code) + require.Contains(t, rr.Body.String(), "request_id=pinned-id-42") + require.Equal(t, "pinned-id-42", rr.Header().Get("X-Request-ID")) +} + +// TestInspector_ChainOrder_RecoverCoversRequestID pins the +// customer-visible consequences of the chain order built by +// [service.NewServiceHandler]: a panic from deep in the handler chain +// must turn into a clean 500 Internal Server Error (not a dropped TCP +// connection), and the response must still carry X-Request-ID on the way +// out so callers and log aggregators retain correlation. 
+// +// The precise semantics of how RecoverMiddleware reads the request id +// from the outer request (via the shared ResponseWriter header, not +// r.Context()) are a property of [service.RecoverMiddleware] itself and +// are documented and tested there. +func TestInspector_ChainOrder_RecoverCoversRequestID(t *testing.T) { + insp, app := newInspectorForTest(t, errPanicSentinel) + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + req.Header.Set("X-Request-ID", "chain-order-test-id") + rr := httptest.NewRecorder() + + require.NotPanics(t, func() { + insp.ServeMux.ServeHTTP(rr, req) + }, "panic in handler must be caught by RecoverMiddleware, not propagate to the test") + + require.Equal(t, http.StatusInternalServerError, rr.Code, + "Recover must turn handler panics into a 500") + require.Equal(t, "chain-order-test-id", rr.Header().Get("X-Request-ID"), + "X-Request-ID must be echoed on the 500 response so clients can correlate") +} + +func TestInspector_HappyPathStillWorks(t *testing.T) { + insp, app := newInspectorForTest(t, nil) + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + require.Contains(t, rr.Body.String(), `"status":"Accepted"`) +} + +func TestInspector_EmptyDappPathReturns404(t *testing.T) { + insp, _ := newInspectorForTest(t, nil) + // An empty dapp segment does not match the /inspect/{dapp} pattern, + // so the mux itself answers 404 and the inspect handler is never + // reached. 
+ req := httptest.NewRequest(http.MethodPost, "/inspect/", strings.NewReader("x")) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + require.Equal(t, http.StatusNotFound, rr.Code) +} + +// TestInspector_RealServer_PayloadTooLarge runs the oversized-body path +// through a real httptest.Server so the full middleware chain and +// MaxBytesReader wire up correctly end-to-end. This is the keystone test +// for the responseWriterTap.Unwrap() behaviour; if the tap ever loses +// Unwrap, MaxBytesReader silently stops enforcing connection-close and +// this test is the first thing to notice. +func TestInspector_RealServer_PayloadTooLarge(t *testing.T) { + insp, app := newInspectorForTest(t, nil) + + srv := httptest.NewServer(insp.ServeMux) + defer srv.Close() + + body := bytes.NewReader(make([]byte, maxPayloadSize+1)) + resp, err := http.Post( + fmt.Sprintf("%s/inspect/%s", srv.URL, app.Name), + "application/octet-stream", + body) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode) + respBody, _ := io.ReadAll(resp.Body) + require.Contains(t, string(respBody), "Payload too large") +} + +// ----------------------------------------------------------------------------- +// Admission control integration +// ----------------------------------------------------------------------------- + +// newInspectorWithAdmission is a variant of newInspectorForTest that also +// accepts a *service.SemaphoreAdmission to exercise the admission +// middleware in the full handler chain. 
+func newInspectorWithAdmission(t *testing.T, admission *service.SemaphoreAdmission) (*Inspector, *Application) { + t.Helper() + + app := &Application{ + ID: 1, + IApplicationAddress: randomAddress(), + Name: "test-app", + ExecutionParameters: ExecutionParameters{ + InspectMaxDeadline: 10 * time.Second, + }, + } + repo := newMockRepository() + repo.apps = append(repo.apps, app) + + mm := newMockMachines() + mm.Map[1] = MockMachine{application: app} + + insp, err := NewInspector(CreateInfo{ + Repository: repo, + Machines: mm, + Address: "127.0.0.1:0", + LogLevel: slog.LevelError, + LogPretty: false, + Admission: admission, + }) + require.NoError(t, err) + return insp, app +} + +func TestInspector_AdmissionAccessor(t *testing.T) { + admission := service.NewSemaphoreAdmission(1) + insp, _ := newInspectorWithAdmission(t, admission) + require.Same(t, admission, insp.Admission(), + "Admission() must return the instance passed via CreateInfo") + + inspNil, _ := newInspectorWithAdmission(t, nil) + require.Nil(t, inspNil.Admission(), + "Admission() must return nil when admission control is disabled") +} + +func TestInspector_AdmissionRejectsWhenExhausted(t *testing.T) { + // Pre-fill a single-permit admission so every subsequent request + // bounces with 503 regardless of payload shape. 
+ admission := service.NewSemaphoreAdmission(1) + admission.TryAcquire() // pre-fill the single permit + insp, app := newInspectorWithAdmission(t, admission) + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + retryAfter, err := strconv.Atoi(rr.Header().Get("Retry-After")) + require.NoError(t, err) + require.GreaterOrEqual(t, retryAfter, 1) + require.LessOrEqual(t, retryAfter, 3) + require.Contains(t, rr.Body.String(), "service at capacity") + require.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) + require.Equal(t, uint64(1), admission.Rejected()) +} + +func TestInspector_AdmissionDisabledWhenNil(t *testing.T) { + // nil Admission should disable the gate entirely. Any request + // should reach the handler. + insp, app := newInspectorWithAdmission(t, nil) + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) +} + +func TestInspector_AdmissionPermitReleasedAfterRequest(t *testing.T) { + admission := service.NewSemaphoreAdmission(1) + insp, app := newInspectorWithAdmission(t, admission) + + for range 5 { + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + require.Equal(t, http.StatusOK, rr.Code, "sequential requests must always succeed at limit=1") + } + require.Equal(t, uint64(0), admission.Rejected()) +} + +// ----------------------------------------------------------------------------- +// Per-app capacity (ErrInspectAtCapacity) +// ----------------------------------------------------------------------------- + +func 
TestInspector_PerAppCapacityReturns503(t *testing.T) { + // Inject ErrInspectAtCapacity as the machine-level error to simulate a + // saturated per-app semaphore. The handler should return 503 with a + // body that says "Application inspect at capacity". + insp, app := newInspectorForTest(t, manager.ErrInspectAtCapacity) + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + require.Contains(t, rr.Body.String(), "Application inspect at capacity") +} + +// ----------------------------------------------------------------------------- +// CORS integration +// ----------------------------------------------------------------------------- + +func newInspectorWithCORS(t *testing.T, origins string) (*Inspector, *Application) { + t.Helper() + + app := &Application{ + ID: 1, + IApplicationAddress: randomAddress(), + Name: "test-app", + } + repo := newMockRepository() + repo.apps = append(repo.apps, app) + + mm := newMockMachines() + mm.Map[1] = MockMachine{application: app} + + insp, err := NewInspector(CreateInfo{ + Repository: repo, + Machines: mm, + Address: "127.0.0.1:0", + LogLevel: slog.LevelError, + LogPretty: false, + CORSAllowedOrigins: origins, + }) + require.NoError(t, err) + return insp, app +} + +func TestInspector_CORSDisabledByDefault(t *testing.T) { + insp, app := newInspectorWithCORS(t, "") + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + req.Header.Set("Origin", "http://evil.com") + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Empty(t, rr.Header().Get("Vary")) +} + +func TestInspector_CORSAllowedOriginEchoed(t *testing.T) { + insp, app := newInspectorWithCORS(t, "http://trusted.example.com") + + req := 
httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + req.Header.Set("Origin", "http://trusted.example.com") + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, "http://trusted.example.com", rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") +} + +func TestInspector_CORSDisallowedOriginNoGrant(t *testing.T) { + insp, app := newInspectorWithCORS(t, "http://trusted.example.com") + + req := httptest.NewRequest(http.MethodPost, + fmt.Sprintf("/inspect/%s", app.Name), + strings.NewReader("hello")) + req.Header.Set("Origin", "http://evil.com") + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") +} + +func TestInspector_CORSPreflightShortCircuits(t *testing.T) { + insp, app := newInspectorWithCORS(t, "http://trusted.example.com") + + req := httptest.NewRequest(http.MethodOptions, + fmt.Sprintf("/inspect/%s", app.Name), nil) + req.Header.Set("Origin", "http://trusted.example.com") + req.Header.Set("Access-Control-Request-Method", "POST") + rr := httptest.NewRecorder() + insp.ServeMux.ServeHTTP(rr, req) + + require.Equal(t, http.StatusNoContent, rr.Code) + require.Equal(t, "http://trusted.example.com", rr.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "3600", rr.Header().Get("Access-Control-Max-Age")) +} + +// ----------------------------------------------------------------------------- + +// TestInspector_ServeReturnsNilOnGracefulShutdown verifies the new Serve() +// method swallows http.ErrServerClosed and returns nil, matching the +// contract the advancer relies on. 
+func TestInspector_ServeReturnsNilOnGracefulShutdown(t *testing.T) { + insp, _ := newInspectorForTest(t, nil) + + // Pre-bind a listener on an OS-assigned port so the test knows the + // real address (http.Server never rewrites Addr after binding to :0). + // Inject it via the listen hook so Serve() uses it. + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := listener.Addr().String() + insp.listen = func(string, string) (net.Listener, error) { return listener, nil } + + serveErr := make(chan error, 1) + go func() { serveErr <- insp.Serve() }() + + // Wait until the server is actually accepting connections before + // shutting down. Any HTTP response (including 404) confirms the + // accept loop is live. + deadline := time.Now().Add(2 * time.Second) + ready := false + for time.Now().Before(deadline) { + resp, err := http.Get("http://" + addr) + if err == nil { + resp.Body.Close() + ready = true + break + } + time.Sleep(5 * time.Millisecond) + } + require.True(t, ready, "server did not start listening in time") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, insp.Shutdown(ctx)) + + select { + case err := <-serveErr: + require.NoError(t, err, "Serve() must return nil on graceful shutdown") + case <-time.After(2 * time.Second): + t.Fatal("Serve did not return after Shutdown") + } +} diff --git a/internal/inspect/inspect.go b/internal/inspect/inspect.go index 660b54792..994c0fb77 100644 --- a/internal/inspect/inspect.go +++ b/internal/inspect/inspect.go @@ -10,13 +10,13 @@ import ( "fmt" "io" "log/slog" + "net" "net/http" - "os" + "sync" "time" "github.com/cartesi/rollups-node/internal/manager" . 
"github.com/cartesi/rollups-node/internal/model" - "github.com/cartesi/rollups-node/internal/services" "github.com/cartesi/rollups-node/pkg/service" "github.com/ethereum/go-ethereum/common/hexutil" ) @@ -27,6 +27,10 @@ import ( // by the machine anyway, so there is no reason to read them into memory. const maxPayloadSize = 1 << 21 // 2 MiB +// inspectResponseHeadroom is the time budget reserved for HTTP response +// serialization after the Cartesi Machine's inspect deadline fires. +const inspectResponseHeadroom = 30 * time.Second + var ( ErrInvalidMachines = errors.New("machines must not be nil") ErrNoApp = errors.New("no application") @@ -43,9 +47,17 @@ type InspectRepository interface { type Inspector struct { IInspectMachines - repository InspectRepository - Logger *slog.Logger - ServeMux *http.ServeMux + repository InspectRepository + Logger *slog.Logger + ServeMux *http.ServeMux + server *http.Server + admission *service.SemaphoreAdmission + deadlineWarnedMu sync.Mutex + deadlineWarned map[int64]struct{} + // listen opens the HTTP listener. It defaults to net.Listen and is + // overridden in tests so Serve() can be exercised against a pre-bound + // listener whose actual address is known to the test. + listen func(network, address string) (net.Listener, error) } type ReportResponse struct { @@ -59,51 +71,88 @@ type InspectResponse struct { ProcessedInputs uint64 `json:"processed_input_count"` } -func NewInspector( - repo InspectRepository, - machines IInspectMachines, - address string, - logLevel slog.Level, - logPretty bool, -) (*Inspector, *http.Server, func() error) { - logger := service.NewLogger(slog.Level(logLevel), logPretty) - logger = logger.With("service", "inspect") +// CreateInfo bundles the parameters for [NewInspector]. +type CreateInfo struct { + Repository InspectRepository + Machines IInspectMachines + Address string + LogLevel slog.Level + LogPretty bool + // Admission is an optional HTTP-level concurrency gate. 
A nil value + // disables admission control; the middleware chain treats nil as a + // pass-through so wiring is uniform regardless of configuration. + Admission *service.SemaphoreAdmission + // CORSAllowedOrigins is the raw comma-separated origin allowlist. + // Empty disables CORS entirely. + CORSAllowedOrigins string +} + +// NewInspector constructs an [Inspector] and its backing HTTP server +// with the node's canonical hardening chain applied via +// [service.NewServiceHandler]. See that helper for the middleware order +// and rationale. +// +// Use [Inspector.Serve] to run the HTTP server and [Inspector.Shutdown] +// to stop it gracefully. +func NewInspector(c CreateInfo) (*Inspector, error) { + if c.Machines == nil { + return nil, ErrInvalidMachines + } + + logger := service.NewLogger(c.LogLevel, c.LogPretty).With("service", "inspect") inspector := &Inspector{ - IInspectMachines: machines, - repository: repo, + IInspectMachines: c.Machines, + repository: c.Repository, Logger: logger, + deadlineWarned: make(map[int64]struct{}), ServeMux: http.NewServeMux(), + admission: c.Admission, } - inspector.ServeMux.Handle("/inspect/{dapp}", services.CorsMiddleware(http.Handler(inspector))) - - server := &http.Server{ - Addr: address, - Handler: inspector.ServeMux, - ErrorLog: slog.NewLogLogger(inspector.Logger.Handler(), slog.LevelError), - } - - return inspector, server, func() error { - maxRetries := 3 // FIXME: should go to config - retryInterval := 5 * time.Second // FIXME: should go to config - inspector.Logger.Info("Create", "LogLevel", logLevel, "pid", os.Getpid()) - inspector.Logger.Info("Listening", "address", address) - var err error = nil - for retry := 0; retry <= maxRetries; retry++ { - switch err = server.ListenAndServe(); err { - case http.ErrServerClosed: - return nil - default: - inspector.Logger.Error("http", - "error", err, - "try", retry+1, - "maxRetries", maxRetries, - "error", err) - } - time.Sleep(retryInterval) - } + handler := 
service.NewServiceHandler(inspector, service.HandlerOptions{ + Logger: logger, + Admission: c.Admission, + CORS: service.ParseCORSConfig(logger, c.CORSAllowedOrigins, + []string{"POST", "OPTIONS"}, []string{"Content-Type"}), + }) + inspector.ServeMux.Handle("/inspect/{dapp}", handler) + + inspector.server = service.NewHTTPServer(c.Address, inspector.ServeMux, service.DefaultInspectOptions(), logger) + inspector.listen = net.Listen + service.StartupBindWarning(logger, "inspect", c.Address) + + return inspector, nil +} + +// Serve opens the HTTP listener and runs the server. Returns nil on +// graceful shutdown. +func (inspect *Inspector) Serve() error { + listener, err := inspect.listen("tcp", inspect.server.Addr) + if err != nil { return err } + inspect.Logger.Info("Listening", "address", listener.Addr().String()) + if err := inspect.server.Serve(listener); !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil +} + +// Shutdown gracefully stops the inspect HTTP server, waiting for in-flight +// requests to complete or ctx to expire. Callers should not access the +// underlying *http.Server directly; exposing only Shutdown keeps the API +// surface minimal and prevents misuse (e.g. reaching for ListenAndServe, +// SetKeepAlivesEnabled, or mutating Handler after construction). +func (inspect *Inspector) Shutdown(ctx context.Context) error { + return inspect.server.Shutdown(ctx) +} + +// Admission returns the concurrency gate used by the inspect HTTP surface, +// or nil when admission control is disabled. This accessor gives the +// advancer (or a future metrics hook) a path to reach the inspect +// admission counters without threading the controller separately. 
+func (inspect *Inspector) Admission() *service.SemaphoreAdmission { + return inspect.admission } func (inspect *Inspector) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -124,46 +173,68 @@ func (inspect *Inspector) ServeHTTP(w http.ResponseWriter, r *http.Request) { } dapp = r.PathValue("dapp") - if r.Method == "POST" { - // Limit the request body to the machine's CMIO RX buffer size. - // Payloads larger than this are rejected by the machine, so reading - // them into memory would only waste resources. We read maxPayloadSize+1 - // bytes so we can distinguish "exactly at limit" from "over limit". - limitedReader := io.LimitReader(r.Body, maxPayloadSize+1) - payload, err = io.ReadAll(limitedReader) - if err != nil { - inspect.Logger.Info("Bad request", "err", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if int64(len(payload)) > maxPayloadSize { + if r.Method != http.MethodPost { + inspect.Logger.Info("HTTP method not allowed", "application", dapp, "method", r.Method) + w.Header().Set("Allow", http.MethodPost) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Cap the request body at the machine's CMIO RX buffer size. MaxBytesReader + // both enforces the limit and signals the server to close the connection + // on over-limit so clients can't pipeline further requests on it. 
+ r.Body = http.MaxBytesReader(w, r.Body, maxPayloadSize) + payload, err = io.ReadAll(r.Body) + if err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { inspect.Logger.Info("Payload too large", - "size", len(payload), - "limit", maxPayloadSize) + "limit", maxPayloadSize, + "application", dapp) http.Error(w, "Payload too large", http.StatusRequestEntityTooLarge) return } - } else { - inspect.Logger.Info("HTTP method not supported", "application", dapp) - http.Error(w, "HTTP method not supported", http.StatusNotFound) + inspect.Logger.Info("Bad request", "err", err, "application", dapp) + http.Error(w, "Bad request", http.StatusBadRequest) return } inspect.Logger.Info("Got new inspect request", "application", dapp) - result, err := inspect.process(r.Context(), dapp, payload) - if err != nil { - if errors.Is(err, ErrMachineNotReady) { - inspect.Logger.Warn("Machine not ready", "application", dapp, "err", err) + + app, machine, resolveErr := inspect.resolveApp(r.Context(), dapp) + if resolveErr != nil { + if errors.Is(resolveErr, ErrMachineNotReady) { + inspect.Logger.Warn("Machine not ready", "application", dapp, "err", resolveErr) http.Error(w, "Machine not ready", http.StatusServiceUnavailable) return } - if errors.Is(err, ErrNoApp) { - inspect.Logger.Error("Application not found", "application", dapp, "err", err) + if errors.Is(resolveErr, ErrNoApp) { + inspect.Logger.Info("Application not found", "application", dapp, "err", resolveErr) http.Error(w, "Application not found", http.StatusNotFound) return } - inspect.Logger.Error("Internal server error", "err", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + service.WriteInternalError(r.Context(), w, inspect.Logger, + fmt.Errorf("inspect resolve failed: %w", resolveErr)) + return + } + + deadline := app.ExecutionParameters.InspectMaxDeadline + inspectResponseHeadroom + if inspect.server != nil && deadline > inspect.server.WriteTimeout { + 
inspect.warnDeadlineExceedsWriteTimeout(app, deadline) + } + ctx, cancel := context.WithTimeout(r.Context(), deadline) + defer cancel() + + result, err := machine.Inspect(ctx, payload) + if err != nil { + if errors.Is(err, manager.ErrInspectAtCapacity) { + inspect.Logger.Info("Application inspect at capacity", + "application", dapp) + http.Error(w, "Application inspect at capacity", http.StatusServiceUnavailable) + return + } + service.WriteInternalError(ctx, w, inspect.Logger, + fmt.Errorf("inspect processing failed: %w", err)) return } @@ -190,11 +261,15 @@ func (inspect *Inspector) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(response) - if err != nil { - inspect.Logger.Error("Internal server error", - "err", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + if err := json.NewEncoder(w).Encode(response); err != nil { + // Headers are already flushed; we can only log. Writing a 500 via + // WriteInternalError here would produce "superfluous WriteHeader" + // warnings and garble the response. 
+ inspect.Logger.Error("failed to encode inspect response", + "err", err, + "application", dapp, + "request_id", service.RequestIDFromContext(r.Context()), + ) return } inspect.Logger.Info("Request executed", @@ -202,29 +277,37 @@ func (inspect *Inspector) ServeHTTP(w http.ResponseWriter, r *http.Request) { "application", dapp) } -// process sends an inspect request to the machine -func (inspect *Inspector) process( +func (inspect *Inspector) warnDeadlineExceedsWriteTimeout(app *Application, deadline time.Duration) { + inspect.deadlineWarnedMu.Lock() + defer inspect.deadlineWarnedMu.Unlock() + if _, seen := inspect.deadlineWarned[app.ID]; seen { + return + } + inspect.deadlineWarned[app.ID] = struct{}{} + inspect.Logger.Warn( + "application inspect deadline exceeds HTTP WriteTimeout; response may be truncated", + "application", app.Name, + "inspect_max_deadline", app.ExecutionParameters.InspectMaxDeadline, + "response_headroom", inspectResponseHeadroom, + "effective_deadline", deadline, + "http_write_timeout", inspect.server.WriteTimeout, + ) +} + +func (inspect *Inspector) resolveApp( ctx context.Context, nameOrAddress string, - query []byte) (*InspectResult, error) { - +) (*Application, manager.MachineInstance, error) { app, err := inspect.repository.GetApplication(ctx, nameOrAddress) if app == nil { if err != nil { - return nil, fmt.Errorf("%w %s", err, nameOrAddress) + return nil, nil, fmt.Errorf("%w %s", err, nameOrAddress) } - return nil, fmt.Errorf("%w %s", ErrNoApp, nameOrAddress) + return nil, nil, fmt.Errorf("%w %s", ErrNoApp, nameOrAddress) } - // Asserts that the app has an associated machine. 
machine, exists := inspect.GetMachine(app.ID) if !exists { - return nil, fmt.Errorf("%w %s", ErrMachineNotReady, nameOrAddress) + return nil, nil, fmt.Errorf("%w %s", ErrMachineNotReady, nameOrAddress) } - - res, err := machine.Inspect(ctx, query) - if err != nil { - return nil, err - } - - return res, nil + return app, machine, nil } diff --git a/internal/inspect/inspect_test.go b/internal/inspect/inspect_test.go index 386eb64b8..c2bec0055 100644 --- a/internal/inspect/inspect_test.go +++ b/internal/inspect/inspect_test.go @@ -11,12 +11,12 @@ import ( "fmt" "log/slog" "net/http" + "net/http/httptest" "testing" "time" "github.com/cartesi/rollups-node/internal/manager" . "github.com/cartesi/rollups-node/internal/model" - "github.com/cartesi/rollups-node/internal/services" "github.com/cartesi/rollups-node/pkg/service" "github.com/ethereum/go-ethereum/common" @@ -24,50 +24,21 @@ import ( "github.com/stretchr/testify/suite" ) -const TestTimeout = 5 * time.Second - func TestInspect(t *testing.T) { suite.Run(t, new(InspectSuite)) } type InspectSuite struct { suite.Suite - ServicePort int - ServiceAddr string -} - -func (s *InspectSuite) SetupSuite() { - s.ServicePort = 5555 -} - -func (s *InspectSuite) SetupTest() { - s.ServicePort++ - s.ServiceAddr = fmt.Sprintf("127.0.0.1:%v", s.ServicePort) } func (s *InspectSuite) TestPostOk() { inspect, app, payload := s.setup() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + srv := s.startServer(inspect) + defer srv.Close() - router := http.NewServeMux() - router.Handle("/inspect/{dapp}", inspect) - httpService := services.HttpService{Name: "http", Address: s.ServiceAddr, Handler: router} - - result := make(chan error, 1) - ready := make(chan struct{}, 1) - go func() { - result <- httpService.Start(ctx, ready, service.NewLogger(slog.LevelDebug, true)) - }() - - select { - case <-ready: - case <-time.After(TestTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } - - resp, err := 
http.Post(fmt.Sprintf("http://%v/inspect/%v", s.ServiceAddr, app.IApplicationAddress.Hex()), + resp, err := http.Post(fmt.Sprintf("%s/inspect/%s", srv.URL, app.IApplicationAddress.Hex()), "application/octet-stream", bytes.NewBuffer(payload.Bytes())) if err != nil { @@ -79,26 +50,10 @@ func (s *InspectSuite) TestPostOk() { func (s *InspectSuite) TestPostWithNameOk() { inspect, app, payload := s.setup() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + srv := s.startServer(inspect) + defer srv.Close() - router := http.NewServeMux() - router.Handle("/inspect/{dapp}", inspect) - httpService := services.HttpService{Name: "http", Address: s.ServiceAddr, Handler: router} - - result := make(chan error, 1) - ready := make(chan struct{}, 1) - go func() { - result <- httpService.Start(ctx, ready, service.NewLogger(slog.LevelDebug, true)) - }() - - select { - case <-ready: - case <-time.After(TestTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } - - resp, err := http.Post(fmt.Sprintf("http://%s/inspect/%s", s.ServiceAddr, app.Name), + resp, err := http.Post(fmt.Sprintf("%s/inspect/%s", srv.URL, app.Name), "application/octet-stream", bytes.NewBuffer(payload.Bytes())) if err != nil { @@ -110,32 +65,16 @@ func (s *InspectSuite) TestPostWithNameOk() { func (s *InspectSuite) TestPostNoApp() { inspect, _, payload := s.setup() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + srv := s.startServer(inspect) + defer srv.Close() - router := http.NewServeMux() - router.Handle("/inspect/{dapp}", inspect) - httpService := services.HttpService{Name: "http", Address: s.ServiceAddr, Handler: router} - - result := make(chan error, 1) - ready := make(chan struct{}, 1) - go func() { - result <- httpService.Start(ctx, ready, service.NewLogger(slog.LevelDebug, true)) - }() - - select { - case <-ready: - case <-time.After(TestTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } - - resp, err := 
http.Post(fmt.Sprintf("http://%s/inspect/%s", s.ServiceAddr, "Aloha"), + resp, err := http.Post(fmt.Sprintf("%s/inspect/%s", srv.URL, "Aloha"), "application/octet-stream", bytes.NewBuffer(payload.Bytes())) s.Require().Nil(err) s.Equal(http.StatusNotFound, resp.StatusCode) - resp, err = http.Post(fmt.Sprintf("http://%s/inspect/%s", s.ServiceAddr, + resp, err = http.Post(fmt.Sprintf("%s/inspect/%s", srv.URL, "0x1000000000000000000000000000000000000000"), "application/octet-stream", bytes.NewBuffer(payload.Bytes())) @@ -144,9 +83,6 @@ func (s *InspectSuite) TestPostNoApp() { } func (s *InspectSuite) TestPostMachineNotReady() { - // App exists in the repository but has no machine in the machines map. - // This simulates the startup window where the advancer hasn't created - // the machine instance yet. Should return 503 Service Unavailable. app := &Application{ ID: 42, IApplicationAddress: randomAddress(), @@ -154,7 +90,7 @@ func (s *InspectSuite) TestPostMachineNotReady() { } repo := newMockRepository() repo.apps = append(repo.apps, app) - machines := newMockMachines() // no machine added for app ID 42 + machines := newMockMachines() inspect := &Inspector{ repository: repo, @@ -162,34 +98,17 @@ func (s *InspectSuite) TestPostMachineNotReady() { Logger: service.NewLogger(slog.LevelDebug, true), } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - router := http.NewServeMux() - router.Handle("/inspect/{dapp}", inspect) - httpService := services.HttpService{Name: "http", Address: s.ServiceAddr, Handler: router} - - ready := make(chan struct{}, 1) - go func() { - _ = httpService.Start(ctx, ready, service.NewLogger(slog.LevelDebug, true)) - }() + srv := s.startServer(inspect) + defer srv.Close() - select { - case <-ready: - case <-time.After(TestTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } - - // Query by name - respByName, err := http.Post(fmt.Sprintf("http://%s/inspect/%s", s.ServiceAddr, app.Name), + respByName, 
err := http.Post(fmt.Sprintf("%s/inspect/%s", srv.URL, app.Name), "application/octet-stream", bytes.NewBuffer([]byte("hello"))) s.Require().Nil(err) defer respByName.Body.Close() s.Equal(http.StatusServiceUnavailable, respByName.StatusCode) - // Query by address - respByAddr, err := http.Post(fmt.Sprintf("http://%s/inspect/%s", s.ServiceAddr, app.IApplicationAddress.Hex()), + respByAddr, err := http.Post(fmt.Sprintf("%s/inspect/%s", srv.URL, app.IApplicationAddress.Hex()), "application/octet-stream", bytes.NewBuffer([]byte("hello"))) s.Require().Nil(err) @@ -199,17 +118,16 @@ func (s *InspectSuite) TestPostMachineNotReady() { func (s *InspectSuite) TestPostMaxPayloadSize() { inspect, app, _ := s.setup() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s.startServer(ctx, inspect) - // A payload exactly at the max size should be accepted. + srv := s.startServer(inspect) + defer srv.Close() + payload := make([]byte, maxPayloadSize) _, err := crand.Read(payload) s.Require().NoError(err) resp, err := http.Post( - fmt.Sprintf("http://%s/inspect/%s", s.ServiceAddr, app.IApplicationAddress.Hex()), + fmt.Sprintf("%s/inspect/%s", srv.URL, app.IApplicationAddress.Hex()), "application/octet-stream", bytes.NewReader(payload)) s.Require().NoError(err) @@ -225,17 +143,16 @@ func (s *InspectSuite) TestPostMaxPayloadSize() { func (s *InspectSuite) TestPostPayloadTooLarge() { inspect, app, _ := s.setup() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s.startServer(ctx, inspect) - // A payload one byte over the max size should be rejected. 
+ srv := s.startServer(inspect) + defer srv.Close() + payload := make([]byte, maxPayloadSize+1) _, err := crand.Read(payload) s.Require().NoError(err) resp, err := http.Post( - fmt.Sprintf("http://%s/inspect/%s", s.ServiceAddr, app.IApplicationAddress.Hex()), + fmt.Sprintf("%s/inspect/%s", srv.URL, app.IApplicationAddress.Hex()), "application/octet-stream", bytes.NewReader(payload)) s.Require().NoError(err) @@ -243,21 +160,10 @@ func (s *InspectSuite) TestPostPayloadTooLarge() { s.Equal(http.StatusRequestEntityTooLarge, resp.StatusCode) } -func (s *InspectSuite) startServer(ctx context.Context, inspect *Inspector) { +func (s *InspectSuite) startServer(inspect *Inspector) *httptest.Server { router := http.NewServeMux() router.Handle("/inspect/{dapp}", inspect) - httpService := services.HttpService{Name: "http", Address: s.ServiceAddr, Handler: router} - - ready := make(chan struct{}, 1) - go func() { - _ = httpService.Start(ctx, ready, service.NewLogger(slog.LevelDebug, true)) - }() - - select { - case <-ready: - case <-time.After(TestTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } + return httptest.NewServer(router) } func (s *InspectSuite) setup() (*Inspector, *Application, common.Hash) { @@ -330,6 +236,7 @@ func (mock *MockMachine) Inspect( return &res, nil } +// Not used in inspect tests, but needed to satisfy the interface func (mock *MockMachine) Advance( _ context.Context, input []byte, @@ -337,7 +244,6 @@ func (mock *MockMachine) Advance( _ uint64, _ bool, ) (*AdvanceResult, error) { - // Not used in inspect tests, but needed to satisfy the interface return nil, nil } @@ -353,24 +259,23 @@ func (m *MockMachine) OutputsProof(ctx context.Context) (*OutputsProof, error) { return nil, nil } +// Not used in inspect tests, but needed to satisfy the interface func (mock *MockMachine) Synchronize(ctx context.Context, repo manager.MachineRepository, batchSize uint64) error { - // Not used in inspect tests, but needed to satisfy the 
interface return nil } +// Not used in inspect tests, but needed to satisfy the interface func (mock *MockMachine) CreateSnapshot(ctx context.Context, processedInputs uint64, path string) error { - // Not used in inspect tests, but needed to satisfy the interface return nil } // Retrieves the hash of the current machine state func (m *MockMachine) Hash(ctx context.Context) ([32]byte, error) { - // Not used in inspect tests, but needed to satisfy the interface return [32]byte{}, nil } +// Not used in inspect tests, but needed to satisfy the interface func (mock *MockMachine) Close() error { - // Not used in inspect tests, but needed to satisfy the interface return nil } @@ -380,6 +285,9 @@ func newMockMachine(id int64) *MockMachine { ID: id, IApplicationAddress: randomAddress(), Name: fmt.Sprintf("app-%v", id), + ExecutionParameters: ExecutionParameters{ + InspectMaxDeadline: 10 * time.Second, + }, }, } } diff --git a/internal/jsonrpc/jsonrpc.go b/internal/jsonrpc/jsonrpc.go index c766e90a8..7b73514fa 100644 --- a/internal/jsonrpc/jsonrpc.go +++ b/internal/jsonrpc/jsonrpc.go @@ -79,6 +79,11 @@ func (s *Service) handleRPC(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() body, err := io.ReadAll(r.Body) if err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + http.Error(w, "Payload too large", http.StatusRequestEntityTooLarge) + return + } http.Error(w, "Failed to read request body", http.StatusBadRequest) return } diff --git a/internal/jsonrpc/service.go b/internal/jsonrpc/service.go index 7f08532e4..f39ad2157 100644 --- a/internal/jsonrpc/service.go +++ b/internal/jsonrpc/service.go @@ -13,7 +13,6 @@ import ( "github.com/cartesi/rollups-node/internal/config" "github.com/cartesi/rollups-node/internal/repository" - "github.com/cartesi/rollups-node/internal/services" "github.com/cartesi/rollups-node/pkg/contracts/inputs" "github.com/cartesi/rollups-node/pkg/contracts/outputs" "github.com/cartesi/rollups-node/pkg/service" @@ -21,6 
+20,8 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi" ) +const jsonrpcShutdownTimeout = 5 * time.Second + // ----------------------------------------------------------------------------- // Service Implementation // ----------------------------------------------------------------------------- @@ -30,6 +31,7 @@ type Service struct { service.Service repository repository.Repository server *http.Server + admission *service.SemaphoreAdmission inputABI *abi.ABI outputABI *abi.ABI // listen opens the HTTP listener. It defaults to net.Listen and is @@ -74,15 +76,23 @@ func Create(ctx context.Context, c *CreateInfo) (*Service, error) { return nil, err } + if c.Config.JsonrpcMaxInflight > 0 { + s.admission = service.NewSemaphoreAdmission(c.Config.JsonrpcMaxInflight) + } + mux := http.NewServeMux() mux.HandleFunc("/rpc", s.handleRPC) - s.server = &http.Server{ - Addr: c.Config.JsonrpcApiAddress, - Handler: services.CorsMiddleware(mux), // FIXME: add proper cors config - WriteTimeout: 30 * time.Second, //nolint: mnd - ReadTimeout: 30 * time.Second, //nolint: mnd - ReadHeaderTimeout: 10 * time.Second, //nolint: mnd - } + + handler := service.NewServiceHandler(mux, service.HandlerOptions{ + Logger: s.Logger, + Admission: s.admission, + CORS: service.ParseCORSConfig(s.Logger, c.Config.JsonrpcCorsAllowedOrigins, + []string{"POST", "OPTIONS"}, []string{"Content-Type"}), + }) + + s.server = service.NewHTTPServer(c.Config.JsonrpcApiAddress, handler, service.DefaultJSONRPCOptions(), s.Logger) + service.StartupBindWarning(s.Logger, "jsonrpc", c.Config.JsonrpcApiAddress) + if s.listen == nil { s.listen = net.Listen } @@ -111,7 +121,7 @@ func (s *Service) Stop(_ bool) []error { s.SetStopping() var errs []error s.Logger.Info("Shutting down JSON-RPC HTTP server", "addr", s.server.Addr) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) //nolint: mnd + ctx, cancel := context.WithTimeout(context.Background(), jsonrpcShutdownTimeout) defer cancel() if err := 
s.server.Shutdown(ctx); err != nil { errs = append(errs, err) diff --git a/internal/jsonrpc/service_test.go b/internal/jsonrpc/service_test.go new file mode 100644 index 000000000..e4ad62ba9 --- /dev/null +++ b/internal/jsonrpc/service_test.go @@ -0,0 +1,223 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package jsonrpc + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/cartesi/rollups-node/pkg/service" + + "github.com/stretchr/testify/require" +) + +// ensureSentinelRejects fires a request against s.server.Handler and +// expects the admission middleware to reject it with a 503. +func ensureSentinelRejects(t *testing.T, s *Service) { + t.Helper() + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + retryAfter, err := strconv.Atoi(rr.Header().Get("Retry-After")) + require.NoError(t, err) + require.GreaterOrEqual(t, retryAfter, 1) + require.LessOrEqual(t, retryAfter, 3) + require.Contains(t, rr.Body.String(), "service at capacity") + require.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) +} + +// TestJSONRPC_HardenedServerOptions checks that Create wires the jsonrpc +// HTTP server through NewHTTPServer with DefaultJSONRPCOptions. If a future +// refactor reintroduces an inline &http.Server{} literal or drops a field, +// this test catches it. 
+func TestJSONRPC_HardenedServerOptions(t *testing.T) { + s := newTestService(t, "jsonrpc-server-options") + require.NotNil(t, s.server) + + opts := service.DefaultJSONRPCOptions() + require.Equal(t, opts.ReadHeaderTimeout, s.server.ReadHeaderTimeout) + require.Equal(t, opts.ReadTimeout, s.server.ReadTimeout) + require.Equal(t, opts.WriteTimeout, s.server.WriteTimeout) + require.Equal(t, opts.IdleTimeout, s.server.IdleTimeout) + require.Equal(t, opts.MaxHeaderBytes, s.server.MaxHeaderBytes) + require.NotNil(t, s.server.ErrorLog) +} + +// TestJSONRPC_RequestIDPropagated verifies the middleware chain echoes a +// valid X-Request-ID back on the response. Runs directly against the +// Service.server handler so the whole chain (Recover -> RequestID -> CORS +// -> Admission -> mux) is exercised. +func TestJSONRPC_RequestIDPropagated(t *testing.T) { + s := newTestService(t, "jsonrpc-request-id") + + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("X-Request-ID", "pinned-xyz") + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + + // handleRPC will reject the empty body as a bad request, but the + // middleware chain still runs and must echo the header. + require.Equal(t, "pinned-xyz", rr.Header().Get("X-Request-ID")) +} + +// ----------------------------------------------------------------------------- +// Oversized body handling +// ----------------------------------------------------------------------------- + +func TestJSONRPC_OversizedBodyReturns413(t *testing.T) { + s := newTestService(t, "jsonrpc-oversize") + + // Build a body that exceeds MAX_BODY_SIZE (1 MiB). 
+ oversized := bytes.Repeat([]byte("x"), MAX_BODY_SIZE+1) + + req := httptest.NewRequest(http.MethodPost, "/rpc", bytes.NewReader(oversized)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusRequestEntityTooLarge, rr.Code) + require.Contains(t, rr.Body.String(), "Payload too large") +} + +// ----------------------------------------------------------------------------- +// Admission control +// ----------------------------------------------------------------------------- + +func TestJSONRPC_AdmissionDisabledWhenZero(t *testing.T) { + // JsonrpcMaxInflight=0 must leave s.admission == nil and make the + // middleware a passthrough. We can't assert "no 503" directly + // without coordinating a slow handler; instead verify the field + // and confirm a basic request reaches handleRPC (which rejects + // an empty body with 400, not 503). + s := newTestServiceWithInflight(t, "jsonrpc-adm-zero", 0) + require.Nil(t, s.admission, "limit=0 must produce nil admission") + + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + require.NotEqual(t, http.StatusServiceUnavailable, rr.Code, + "disabled admission must not return 503") +} + +func TestJSONRPC_AdmissionWiredWhenPositive(t *testing.T) { + // JsonrpcMaxInflight>0 must construct a non-nil SemaphoreAdmission + // with the matching limit. + s := newTestServiceWithInflight(t, "jsonrpc-adm-wired", 7) + require.NotNil(t, s.admission) + require.Equal(t, uint64(7), s.admission.Limit()) +} + +func TestJSONRPC_AdmissionRejectsWhenExhausted(t *testing.T) { + // Replace the post-Create admission with a pre-filled semaphore + // and rebuild just the middleware stack onto a fresh mux. 
This + // mirrors only the admission layer of Create()'s wiring, inlined here + // to exercise the rejection path without a blocking RPC handler. + s := newTestServiceWithInflight(t, "jsonrpc-adm-reject", 1) + + // Swap the admission underlying the server handler for a + // pre-filled one to force rejection on every request. + s.admission = service.NewSemaphoreAdmission(1) + s.admission.TryAcquire() // pre-fill the single permit + s.server.Handler = rebuildHandlerWithAdmission(s) + + ensureSentinelRejects(t, s) + require.Equal(t, uint64(1), s.admission.Rejected()) +} + +func TestJSONRPC_AdmissionPermitReleasedAfterRequest(t *testing.T) { + // With limit=1 a sequential burst must all be admitted: each request + // releases its permit on return. + s := newTestServiceWithInflight(t, "jsonrpc-adm-release", 1) + require.NotNil(t, s.admission) + + for range 5 { + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + // handleRPC returns 400 on empty body; the key assertion is + // that we never see 503 because the permit is released each + // time the handler returns. + require.NotEqual(t, http.StatusServiceUnavailable, rr.Code) + } + require.Equal(t, uint64(0), s.admission.Rejected()) +} + +// rebuildHandlerWithAdmission wraps a fresh mux serving handleRPC with only +// AdmissionMiddleware bound to the service's current admission field. +// Used by tests that swap the admission after construction to exercise +// the rejection path. The outer chain (Recover, RequestID, CORS) from +// Create() is deliberately omitted: admission is what these tests +// pin, and the other layers are inert for a plain POST without Origin. 
+func rebuildHandlerWithAdmission(s *Service) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/rpc", s.handleRPC) + + var handler http.Handler = mux + handler = service.AdmissionMiddleware(s.admission)(handler) + return handler +} + +// ----------------------------------------------------------------------------- +// CORS integration +// ----------------------------------------------------------------------------- + +func TestJSONRPC_CORSDisabledByDefault(t *testing.T) { + s := newTestService(t, "jsonrpc-cors-off") + + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", "http://evil.com") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Empty(t, rr.Header().Get("Vary")) +} + +func TestJSONRPC_CORSAllowedOriginEchoed(t *testing.T) { + s := newTestServiceWithCORS(t, "jsonrpc-cors-on", "http://trusted.example.com") + + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", "http://trusted.example.com") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + + require.Equal(t, "http://trusted.example.com", rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") +} + +func TestJSONRPC_CORSDisallowedOriginNoGrant(t *testing.T) { + s := newTestServiceWithCORS(t, "jsonrpc-cors-reject", "http://trusted.example.com") + + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", "http://evil.com") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") +} + +func 
TestJSONRPC_CORSNoOriginPassthrough(t *testing.T) { + s := newTestServiceWithCORS(t, "jsonrpc-cors-no-origin", "http://trusted.example.com") + + req := httptest.NewRequest(http.MethodPost, "/rpc", http.NoBody) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + s.server.Handler.ServeHTTP(rr, req) + + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.NotEqual(t, http.StatusServiceUnavailable, rr.Code) +} diff --git a/internal/jsonrpc/util_test.go b/internal/jsonrpc/util_test.go index 629d56b51..e68660282 100644 --- a/internal/jsonrpc/util_test.go +++ b/internal/jsonrpc/util_test.go @@ -71,6 +71,21 @@ type testRPCResponse[T any] struct { } func newTestService(t *testing.T, name string) *Service { + return newTestServiceWithInflight(t, name, 0) +} + +// newTestServiceWithInflight is like newTestService but lets the caller +// set CARTESI_JSONRPC_MAX_INFLIGHT for admission-control tests. A value +// of 0 disables admission (the Create-time nil-admission path). 
+func newTestServiceWithInflight(t *testing.T, name string, maxInflight uint64) *Service { + return newTestServiceFull(t, name, maxInflight, "") +} + +func newTestServiceWithCORS(t *testing.T, name string, origins string) *Service { + return newTestServiceFull(t, name, 0, origins) +} + +func newTestServiceFull(t *testing.T, name string, maxInflight uint64, corsOrigins string) *Service { ctx := context.Background() dbTestEndpoint, err := db.GetTestDatabaseEndpoint() @@ -91,6 +106,10 @@ func newTestService(t *testing.T, name string) *Service { LogLevel: logLevel, LogColor: true, }, + Config: config.JsonrpcConfig{ + JsonrpcMaxInflight: maxInflight, + JsonrpcCorsAllowedOrigins: corsOrigins, + }, Repository: repo, } s, err := Create(ctx, &ci) diff --git a/internal/manager/instance.go b/internal/manager/instance.go index 1bb6ea4e5..56274e8a1 100644 --- a/internal/manager/instance.go +++ b/internal/manager/instance.go @@ -29,6 +29,7 @@ var ( ErrInvalidAdvanceTimeout = errors.New("advance timeout must not be negative") ErrInvalidInspectTimeout = errors.New("inspect timeout must not be negative") ErrInvalidConcurrentLimit = errors.New("maximum concurrent inspects must not be zero") + ErrInspectAtCapacity = errors.New("application inspect at capacity") ) // MachineInstanceImpl represents a running Cartesi machine for an application. @@ -389,10 +390,11 @@ func (m *MachineInstanceImpl) forkForInspect(ctx context.Context) (machine.Machi // Inspect queries the machine state without modifying it func (m *MachineInstanceImpl) Inspect(ctx context.Context, query []byte) (*InspectResult, error) { - // Limit concurrent inspects - err := m.inspectSemaphore.Acquire(ctx, 1) - if err != nil { - return nil, err + // Limit concurrent inspects. TryAcquire is non-blocking so that a + // saturated application fails fast and releases its HTTP admission + // permit, preventing one app from starving others on the same node. 
+ if !m.inspectSemaphore.TryAcquire(1) { + return nil, ErrInspectAtCapacity } defer m.inspectSemaphore.Release(1) diff --git a/internal/manager/instance_test.go b/internal/manager/instance_test.go index 8a77931a7..995375e17 100644 --- a/internal/manager/instance_test.go +++ b/internal/manager/instance_test.go @@ -610,20 +610,19 @@ func (s *MachineInstanceSuite) TestInspect() { }) s.Run("Error", func() { - s.Run("Acquire", func() { + s.Run("AtCapacity", func() { require := s.Require() _, _, machine := s.setupInspect() - // Set semaphore to 0 to force acquisition failure + // Pre-fill all semaphore slots to simulate a saturated app machine.inspectSemaphore.TryAcquire(int64(machine.maxConcurrentInspects)) - ctx, cancel := context.WithTimeout(context.Background(), centisecond) - defer cancel() - - res, err := machine.Inspect(ctx, []byte{}) + // TryAcquire is non-blocking: the error is returned immediately, + // no context deadline required. + res, err := machine.Inspect(context.Background(), []byte{}) require.Error(err) require.Nil(res) - require.ErrorIs(err, context.DeadlineExceeded) + require.ErrorIs(err, ErrInspectAtCapacity) // Release the semaphore for cleanup machine.inspectSemaphore.Release(int64(machine.maxConcurrentInspects)) diff --git a/internal/services/http.go b/internal/services/http.go deleted file mode 100644 index 6b1d03298..000000000 --- a/internal/services/http.go +++ /dev/null @@ -1,79 +0,0 @@ -// (c) Cartesi and individual authors (see AUTHORS) -// SPDX-License-Identifier: Apache-2.0 (see LICENSE) - -package services - -import ( - "context" - "errors" - "log/slog" - "net" - "net/http" - "time" -) - -const DefaultServiceTimeout = 1 * time.Minute - -// FIXME: Simple CORS middleware. 
Improve this -func CorsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Set CORS headers - w.Header().Set("Access-Control-Allow-Origin", "*") // Allow all origins - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") - - // Handle preflight (OPTIONS) request - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } - - // Proceed with the next handler if not preflight - next.ServeHTTP(w, r) - }) -} - -// Used for testing -type HttpService struct { - Name string - Address string - Handler http.Handler -} - -func (s *HttpService) String() string { - return s.Name -} - -func (s *HttpService) Start(ctx context.Context, ready chan<- struct{}, logger *slog.Logger) error { - server := http.Server{ - Addr: s.Address, - Handler: CorsMiddleware(s.Handler), - ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), - } - - listener, err := net.Listen("tcp", s.Address) - if err != nil { - return err - } - - logger.Info("HTTP server started listening", "service", s, "port", listener.Addr()) - ready <- struct{}{} - - done := make(chan error, 1) - go func() { - err := server.Serve(listener) - if !errors.Is(err, http.ErrServerClosed) { - logger.Warn("Service exited with error", "service", s, "error", err) - } - done <- err - }() - - select { - case err = <-done: - return err - case <-ctx.Done(): - ctx, cancel := context.WithTimeout(context.Background(), DefaultServiceTimeout) - defer cancel() - return server.Shutdown(ctx) - } -} diff --git a/internal/services/http_test.go b/internal/services/http_test.go deleted file mode 100644 index 2b9329644..000000000 --- a/internal/services/http_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// (c) Cartesi and individual authors (see AUTHORS) -// SPDX-License-Identifier: Apache-2.0 (see LICENSE) - -package services - -import ( - 
"context" - "fmt" - "io" - "log/slog" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/suite" -) - -type HttpServiceSuite struct { - suite.Suite - ServicePort int - ServiceAddr string -} - -func TestHttpService(t *testing.T) { - suite.Run(t, new(HttpServiceSuite)) -} - -func (s *HttpServiceSuite) SetupSuite() { - s.ServicePort = 5555 -} - -func (s *HttpServiceSuite) SetupTest() { - s.ServicePort++ - s.ServiceAddr = fmt.Sprintf("127.0.0.1:%v", s.ServicePort) -} - -func (s *HttpServiceSuite) TestItStopsWhenContextIsClosed() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - service := HttpService{Name: "http", Address: s.ServiceAddr, Handler: http.NewServeMux()} - - result := make(chan error, 1) - ready := make(chan struct{}, 1) - go func() { - result <- service.Start(ctx, ready, slog.Default()) - }() - - select { - case <-ready: - cancel() - case <-time.After(DefaultServiceTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } - - select { - case err := <-result: - s.Nil(err) - case <-time.After(DefaultServiceTimeout): - s.FailNow("timed out waiting for HttpService to stop") - } -} - -func (s *HttpServiceSuite) TestItRespondsToRequests() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - router := http.NewServeMux() - router.HandleFunc("/test", defaultHandler) - service := HttpService{Name: "http", Address: s.ServiceAddr, Handler: router} - - result := make(chan error, 1) - ready := make(chan struct{}, 1) - go func() { - result <- service.Start(ctx, ready, slog.Default()) - }() - - select { - case <-ready: - case <-time.After(DefaultServiceTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } - - resp, err := http.Get(fmt.Sprintf("http://%v/test", s.ServiceAddr)) - if err != nil { - s.FailNow(err.Error()) - } - s.assertResponse(resp) -} - -func (s *HttpServiceSuite) TestItRespondsOngoingRequestsAfterContextIsClosed() { - ctx, cancel := 
context.WithCancel(context.Background()) - defer cancel() - - router := http.NewServeMux() - router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - // simulate a long-running request - <-time.After(100 * time.Millisecond) - fmt.Fprintf(w, "test") - }) - service := HttpService{Name: "http", Address: s.ServiceAddr, Handler: router} - - result := make(chan error, 1) - ready := make(chan struct{}, 1) - go func() { - result <- service.Start(ctx, ready, slog.Default()) - }() - - select { - case <-ready: - case <-time.After(DefaultServiceTimeout): - s.FailNow("timed out waiting for HttpService to be ready") - } - - clientResult := make(chan ClientResult, 1) - go func() { - resp, err := http.Get(fmt.Sprintf("http://%v/test", s.ServiceAddr)) - clientResult <- ClientResult{Response: resp, Error: err} - }() - - // wait a bit so server has enough time to start responding the request - <-time.After(200 * time.Millisecond) - cancel() - - select { - case res := <-clientResult: - s.Nil(res.Error) - s.assertResponse(res.Response) - err := <-result - s.Nil(err) - case <-result: - s.FailNow("HttpService closed before responding") - } -} - -type ClientResult struct { - Response *http.Response - Error error -} - -func defaultHandler(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, "test") -} - -func (s *HttpServiceSuite) assertResponse(resp *http.Response) { - s.Equal(http.StatusOK, resp.StatusCode) - - defer resp.Body.Close() - - bytes, err := io.ReadAll(resp.Body) - if err != nil { - s.FailNow("failed to read response body. 
", err) - } - s.Equal([]byte("test"), bytes) -} diff --git a/pkg/service/http_admission.go b/pkg/service/http_admission.go new file mode 100644 index 000000000..3d3203087 --- /dev/null +++ b/pkg/service/http_admission.go @@ -0,0 +1,112 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "fmt" + "math" + "math/rand/v2" + "net/http" + "strconv" + "sync/atomic" + + "golang.org/x/sync/semaphore" +) + +const ( + retryAfterMin = 1 + retryAfterMax = 3 +) + +// SemaphoreAdmission caps the number of concurrent HTTP requests that may +// be in flight on a given route. It is a front-door gate used by callers +// (inspect, jsonrpc) to fail fast under saturation before spending memory +// or backend resources on a request. +// +// Production callers disable admission by passing nil to +// [AdmissionMiddleware], which produces a no-op middleware so callers can +// wire admission unconditionally and disable it via configuration. +type SemaphoreAdmission struct { + sem *semaphore.Weighted + limit uint64 + rejected atomic.Uint64 +} + +// NewSemaphoreAdmission constructs a [SemaphoreAdmission] with the given +// capacity. Values above [math.MaxInt64] are clamped to MaxInt64 because +// the underlying semaphore primitive is int64-based; realistic operator +// inputs are orders of magnitude below this bound. +// +// Production callers disable admission by passing nil to +// [AdmissionMiddleware], not by constructing a zero-limit controller. +func NewSemaphoreAdmission(limit uint64) *SemaphoreAdmission { + sizedLimit := min(limit, uint64(math.MaxInt64)) + return &SemaphoreAdmission{ + sem: semaphore.NewWeighted(int64(sizedLimit)), //nolint:gosec // G115: clamped above + limit: sizedLimit, + } +} + +// TryAcquire attempts to reserve one in-flight slot without blocking. +// It returns true when a permit was obtained; false otherwise. Rejected +// attempts bump the monotonic [Rejected] counter. 
+func (a *SemaphoreAdmission) TryAcquire() bool { + if !a.sem.TryAcquire(1) { + a.rejected.Add(1) + return false + } + return true +} + +// Release returns a permit previously obtained via [TryAcquire]. It is +// an error to call Release without a matching successful TryAcquire. +func (a *SemaphoreAdmission) Release() { + a.sem.Release(1) +} + +// Rejected returns the total number of permit attempts that were +// rejected since construction. Monotonic. +func (a *SemaphoreAdmission) Rejected() uint64 { + return a.rejected.Load() +} + +// Limit returns the configured capacity the controller was constructed with. +func (a *SemaphoreAdmission) Limit() uint64 { + return a.limit +} + +// AdmissionMiddleware returns a middleware that gates requests through +// the supplied controller. When ac is nil the middleware is a no-op +// pass-through, so callers can wire it unconditionally and disable +// admission via configuration. +// +// On reject: 503 Service Unavailable, Retry-After: 1-3 (jittered), +// body "service at capacity (request_id=%s)\n", no per-request log +// (logging every rejection would amplify a flood). On admit: defers +// [SemaphoreAdmission.Release] so panics in downstream handlers still +// free the permit.
+func AdmissionMiddleware(ac *SemaphoreAdmission) func(http.Handler) http.Handler { + if ac == nil { + return func(next http.Handler) http.Handler { return next } + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !ac.TryAcquire() { + reqID := RequestIDFromContext(r.Context()) + jitter := retryAfterMin + rand.IntN(retryAfterMax-retryAfterMin+1) //nolint:gosec // G404: jitter, not security + w.Header().Set("Retry-After", strconv.Itoa(jitter)) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusServiceUnavailable) + // gosec G705 is a false positive: text/plain with nosniff, + // and reqID is validated upstream by RequestIDMiddleware + // against a safe charset, so no taint reaches a browser. + _, _ = fmt.Fprintf(w, "service at capacity (request_id=%s)\n", reqID) //nolint:gosec // G705 + return + } + defer ac.Release() + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/service/http_admission_test.go b/pkg/service/http_admission_test.go new file mode 100644 index 000000000..8e97ef7ca --- /dev/null +++ b/pkg/service/http_admission_test.go @@ -0,0 +1,228 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bytes" + "errors" + "net/http" + "net/http/httptest" + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +// ----------------------------------------------------------------------------- +// SemaphoreAdmission +// ----------------------------------------------------------------------------- + +func TestSemaphoreAdmission_BasicAcquireRelease(t *testing.T) { + a := NewSemaphoreAdmission(2) + + require.True(t, a.TryAcquire()) + require.True(t, a.TryAcquire()) + require.False(t, a.TryAcquire(), "third should be rejected at limit 2") + + a.Release() + require.True(t, a.TryAcquire(), 
"after one release a new permit should be available") +} + +func TestSemaphoreAdmission_RejectedCounter(t *testing.T) { + a := NewSemaphoreAdmission(1) + + require.True(t, a.TryAcquire()) + for range 4 { + require.False(t, a.TryAcquire()) + } + require.Equal(t, uint64(4), a.Rejected()) + + a.Release() + require.True(t, a.TryAcquire()) + require.Equal(t, uint64(4), a.Rejected(), "successful acquire must not bump counter") +} + +func TestSemaphoreAdmission_ZeroLimit(t *testing.T) { + // A zero-capacity semaphore naturally rejects every TryAcquire. + a := NewSemaphoreAdmission(0) + + for range 3 { + require.False(t, a.TryAcquire()) + } + require.Equal(t, uint64(3), a.Rejected()) +} + +// ----------------------------------------------------------------------------- +// AdmissionMiddleware +// ----------------------------------------------------------------------------- + +func TestAdmissionMiddleware_NilIsPassthrough(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + mw := AdmissionMiddleware(nil) + + rr := httptest.NewRecorder() + mw(handler).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "ok", rr.Body.String()) +} + +func TestAdmissionMiddleware_AdmitsBelowLimit(t *testing.T) { + ac := NewSemaphoreAdmission(2) + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mw := AdmissionMiddleware(ac) + + for range 5 { + rr := httptest.NewRecorder() + mw(handler).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + require.Equal(t, http.StatusOK, rr.Code) + } + require.Equal(t, uint64(0), ac.Rejected()) +} + +func TestAdmissionMiddleware_RejectsAtLimit(t *testing.T) { + ac := NewSemaphoreAdmission(1) + + // A handler that signals when it has been entered (i.e. 
the permit + // has been acquired by AdmissionMiddleware) and then blocks until + // we release it. This gives the test a deterministic point to fire + // the second request without any time-based polling. + entered := make(chan struct{}) + block := make(chan struct{}) + done := make(chan struct{}) + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + close(entered) + <-block + w.WriteHeader(http.StatusOK) + }) + + wrapped := AdmissionMiddleware(ac)(handler) + + // Fire the first request in a goroutine; it will hold the permit + // until we close(block). + go func() { + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + close(done) + }() + + <-entered // first request has the permit + + // Second request — must be rejected immediately. + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + retryAfter, err := strconv.Atoi(rr.Header().Get("Retry-After")) + require.NoError(t, err) + require.GreaterOrEqual(t, retryAfter, 1) + require.LessOrEqual(t, retryAfter, 3) + require.Contains(t, rr.Body.String(), "service at capacity") + require.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) + require.Equal(t, uint64(1), ac.Rejected()) + + close(block) + <-done +} + +func TestAdmissionMiddleware_ReleasesOnSuccess(t *testing.T) { + ac := NewSemaphoreAdmission(1) + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mw := AdmissionMiddleware(ac) + + for range 3 { + rr := httptest.NewRecorder() + mw(handler).ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + require.Equal(t, http.StatusOK, rr.Code) + } + // The permit should always be back after each sequential request. 
+ require.Equal(t, uint64(0), ac.Rejected()) +} + +// TestAdmissionReleaseOnPanic pins the "permit is released even when a +// downstream handler panics" invariant. This must work under the real +// middleware chain — RecoverMiddleware wraps AdmissionMiddleware and +// catches the panic, but only after AdmissionMiddleware's deferred +// Release has run. +func TestAdmissionReleaseOnPanic(t *testing.T) { + ac := NewSemaphoreAdmission(2) + + panicking := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + panic(errors.New("kaboom")) + }) + + // Chain: Recover(Admission(panicking)). Admission defers its + // Release(), then the handler panics, then the defer runs (releasing + // the permit), then Recover catches the panic and writes a 500. + var buf bytes.Buffer + chain := RecoverMiddleware(captureLogger(&buf))(AdmissionMiddleware(ac)(panicking)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + chain.ServeHTTP(rr, req) + + require.Equal(t, http.StatusInternalServerError, rr.Code) + + // Both permits should be available: verify by acquiring both in a row. + require.True(t, ac.TryAcquire(), "permit 1 must be available after panic") + require.True(t, ac.TryAcquire(), "permit 2 must be available after panic") + ac.Release() + ac.Release() +} + +func TestAdmissionMiddleware_NoLogSpam(t *testing.T) { + ac := NewSemaphoreAdmission(1) + ac.TryAcquire() // pre-fill the single permit so all subsequent acquires fail + handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}) + + // No logger is wired into AdmissionMiddleware itself, so rejection + // must produce no log output regardless of what wraps it. Verify + // that wrapping in RecoverMiddleware with a capture logger doesn't + // pick up any rejection logs. 
+ var buf bytes.Buffer + chain := RecoverMiddleware(captureLogger(&buf))(AdmissionMiddleware(ac)(handler)) + + for range 10 { + rr := httptest.NewRecorder() + chain.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + } + require.Empty(t, buf.String(), "rejection must not log per-request") +} + +func TestAdmissionMiddleware_ConcurrentStress(t *testing.T) { + ac := NewSemaphoreAdmission(4) + + // Trivial handler: it returns immediately and does not yield, so + // collisions on the semaphore depend on goroutine scheduling. + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mw := AdmissionMiddleware(ac) + wrapped := mw(handler) + + var wg sync.WaitGroup + for range 100 { + wg.Add(1) + go func() { + defer wg.Done() + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/", nil)) + }() + } + wg.Wait() + + // After everything settles, all permits must be back. + for range 4 { + require.True(t, ac.TryAcquire()) + } +} diff --git a/pkg/service/http_cors.go b/pkg/service/http_cors.go new file mode 100644 index 000000000..3198640f0 --- /dev/null +++ b/pkg/service/http_cors.go @@ -0,0 +1,207 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bufio" + "log/slog" + "net" + "net/http" + "strings" +) + +// CORSConfig holds the pre-validated, pre-joined CORS policy for a single +// HTTP surface (JSON-RPC or inspect). When the allowedOrigins map is empty +// the config is disabled and the middleware is a no-op pass-through. +type CORSConfig struct { + allowedOrigins map[string]struct{} // lowercased at construction + methods string // pre-joined for header value + headers string // pre-joined for header value + maxAge string // "3600" +} + +// Enabled reports whether the config has at least one valid allowed origin.
+func (c CORSConfig) Enabled() bool { + return len(c.allowedOrigins) > 0 +} + +// ParseCORSConfig splits origins on comma, validates each entry, and returns +// a ready-to-use [CORSConfig]. Invalid entries are logged as warnings and +// skipped. If no valid origins remain, the returned config is disabled +// ([CORSConfig.Enabled] returns false). +// +// Validation rules per origin: +// - Empty/whitespace-only entries are silently skipped. +// - "null" is rejected (sandboxed iframe attack vector). +// - Entries with paths, query strings, or fragments are rejected. +// - A single trailing slash is stripped (bare origin normalization). +// - The entire origin is lowercased per RFC 6454. +func ParseCORSConfig(logger *slog.Logger, origins string, methods []string, headers []string) CORSConfig { + cfg := CORSConfig{ + allowedOrigins: make(map[string]struct{}), + methods: strings.Join(methods, ", "), + headers: strings.Join(headers, ", "), + maxAge: "3600", + } + + for raw := range strings.SplitSeq(origins, ",") { + entry := strings.TrimSpace(raw) + if entry == "" { + continue + } + + if strings.EqualFold(entry, "null") { + logger.Warn("CORS origin \"null\" rejected (sandboxed iframe attack vector)", + "origin", entry, + ) + continue + } + + entry = strings.TrimSuffix(entry, "/") + + if scheme, rest, ok := strings.Cut(entry, "://"); ok { + _ = scheme + if strings.ContainsAny(rest, "/?#") { + logger.Warn("CORS origin rejected: must not contain path, query, or fragment", + "origin", raw, + ) + continue + } + } else if strings.ContainsAny(entry, "/?#") { + logger.Warn("CORS origin rejected: must not contain path, query, or fragment", + "origin", raw, + ) + continue + } + + lower := strings.ToLower(entry) + cfg.allowedOrigins[lower] = struct{}{} + } + + if len(cfg.allowedOrigins) == 0 { + cfg.allowedOrigins = nil + } + return cfg +} + +// corsWriter wraps an [http.ResponseWriter] and injects CORS headers at +// WriteHeader time, after the handler has had a chance to set (or 
+// overwrite) response headers. This ensures CORS headers are authoritative +// regardless of what downstream handlers do, including setting their own +// Access-Control-Allow-Origin, which the wrapper strips and replaces. +type corsWriter struct { + http.ResponseWriter + origin string + cfg CORSConfig + written bool +} + +func (cw *corsWriter) setCORSHeaders() { + h := cw.Header() + h.Add("Vary", "Origin") + h.Del("Access-Control-Allow-Origin") + h.Set("Access-Control-Allow-Origin", cw.origin) + h.Set("Access-Control-Allow-Methods", cw.cfg.methods) + h.Set("Access-Control-Allow-Headers", cw.cfg.headers) +} + +func (cw *corsWriter) WriteHeader(code int) { + if !cw.written { + cw.written = true + cw.setCORSHeaders() + } + cw.ResponseWriter.WriteHeader(code) +} + +func (cw *corsWriter) Write(b []byte) (int, error) { + if !cw.written { + cw.written = true + cw.setCORSHeaders() + } + return cw.ResponseWriter.Write(b) +} + +func (cw *corsWriter) Unwrap() http.ResponseWriter { return cw.ResponseWriter } + +func (cw *corsWriter) Flush() { + if !cw.written { + cw.written = true + cw.setCORSHeaders() + } + if f, ok := cw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (cw *corsWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := cw.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, http.ErrNotSupported +} + +// CORSMiddleware returns a middleware that enforces the given [CORSConfig]. +// When the config is disabled ([CORSConfig.Enabled] returns false), the +// returned middleware is a no-op identity wrapper, the same pattern used +// by [AdmissionMiddleware] when its controller is nil. +// +// Preflight requests (OPTIONS with Access-Control-Request-Method) are +// short-circuited with a 204 before downstream handlers (and therefore +// admission control) run; only allowed origins receive grant headers.
+func CORSMiddleware(cfg CORSConfig) func(http.Handler) http.Handler { + if !cfg.Enabled() { + return func(next http.Handler) http.Handler { return next } + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + next.ServeHTTP(w, r) + return + } + + // Short-circuit ALL preflights before admission. Allowed + // origins get CORS grant headers; disallowed origins get + // Vary only. Either way, downstream handlers (and their + // admission permits) are never touched. + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + w.Header().Add("Vary", "Origin") + w.Header().Add("Vary", "Access-Control-Request-Method") + w.Header().Add("Vary", "Access-Control-Request-Headers") + lower := strings.ToLower(origin) + if _, ok := cfg.allowedOrigins[lower]; ok { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", cfg.methods) + w.Header().Set("Access-Control-Allow-Headers", cfg.headers) + w.Header().Set("Access-Control-Max-Age", cfg.maxAge) + } + w.WriteHeader(http.StatusNoContent) + return + } + + lower := strings.ToLower(origin) + if _, ok := cfg.allowedOrigins[lower]; !ok { + // Disallowed origin: set Vary for caching correctness + // but emit no CORS grant headers. + w.Header().Add("Vary", "Origin") + next.ServeHTTP(w, r) + return + } + + // Non-preflight allowed origin: wrap the response writer so + // CORS headers (including Vary: Origin) are injected at + // WriteHeader time, after the handler has had a chance to + // set (and potentially overwrite) headers. + cw := &corsWriter{ResponseWriter: w, origin: origin, cfg: cfg} + next.ServeHTTP(cw, r) + + // If the handler returned without calling WriteHeader or + // Write (e.g., an empty 200), inject CORS headers now so + // they are present when the stdlib flushes. 
+ if !cw.written { + cw.setCORSHeaders() + } + }) + } +} diff --git a/pkg/service/http_cors_test.go b/pkg/service/http_cors_test.go new file mode 100644 index 000000000..d4fbb2cb6 --- /dev/null +++ b/pkg/service/http_cors_test.go @@ -0,0 +1,359 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ParseCORSConfig — constructor tests +// --------------------------------------------------------------------------- + +func TestParseCORSConfig_EmptyString(t *testing.T) { + cfg := ParseCORSConfig(discardLogger(), "", []string{"POST"}, []string{"Content-Type"}) + require.False(t, cfg.Enabled()) +} + +func TestParseCORSConfig_WhitespaceOnly(t *testing.T) { + cfg := ParseCORSConfig(discardLogger(), " , , ", []string{"POST"}, []string{"Content-Type"}) + require.False(t, cfg.Enabled()) +} + +func TestParseCORSConfig_NullRejected(t *testing.T) { + var buf bytes.Buffer + cfg := ParseCORSConfig(captureLogger(&buf), "null", []string{"POST"}, []string{"Content-Type"}) + require.False(t, cfg.Enabled()) + require.Contains(t, buf.String(), "null") + require.Contains(t, buf.String(), "level=WARN") +} + +func TestParseCORSConfig_NullAmongValid(t *testing.T) { + var buf bytes.Buffer + cfg := ParseCORSConfig(captureLogger(&buf), "http://ok.com,null", []string{"POST"}, []string{"Content-Type"}) + require.True(t, cfg.Enabled()) + _, ok := cfg.allowedOrigins["http://ok.com"] + require.True(t, ok) + _, ok = cfg.allowedOrigins["null"] + require.False(t, ok) +} + +func TestParseCORSConfig_TrailingSlashStripped(t *testing.T) { + cfg := ParseCORSConfig(discardLogger(), "http://example.com/", []string{"POST"}, []string{"Content-Type"}) + require.True(t, cfg.Enabled()) + _, ok := 
cfg.allowedOrigins["http://example.com"] + require.True(t, ok) +} + +func TestParseCORSConfig_PathRejected(t *testing.T) { + var buf bytes.Buffer + cfg := ParseCORSConfig(captureLogger(&buf), "http://example.com/foo", []string{"POST"}, []string{"Content-Type"}) + require.False(t, cfg.Enabled()) + require.Contains(t, buf.String(), "level=WARN") +} + +func TestParseCORSConfig_QueryRejected(t *testing.T) { + var buf bytes.Buffer + cfg := ParseCORSConfig(captureLogger(&buf), "http://example.com?x=1", []string{"POST"}, []string{"Content-Type"}) + require.False(t, cfg.Enabled()) + require.Contains(t, buf.String(), "level=WARN") +} + +func TestParseCORSConfig_FragmentRejected(t *testing.T) { + var buf bytes.Buffer + cfg := ParseCORSConfig(captureLogger(&buf), "http://example.com#f", []string{"POST"}, []string{"Content-Type"}) + require.False(t, cfg.Enabled()) + require.Contains(t, buf.String(), "level=WARN") +} + +func TestParseCORSConfig_Lowercased(t *testing.T) { + cfg := ParseCORSConfig(discardLogger(), "HTTP://EXAMPLE.COM", []string{"POST"}, []string{"Content-Type"}) + require.True(t, cfg.Enabled()) + _, ok := cfg.allowedOrigins["http://example.com"] + require.True(t, ok) +} + +func TestParseCORSConfig_ValidOrigins(t *testing.T) { + cfg := ParseCORSConfig( + discardLogger(), + "http://localhost:3000, https://app.example.com", + []string{"POST", "OPTIONS"}, + []string{"Content-Type"}, + ) + require.True(t, cfg.Enabled()) + _, ok := cfg.allowedOrigins["http://localhost:3000"] + require.True(t, ok) + _, ok = cfg.allowedOrigins["https://app.example.com"] + require.True(t, ok) +} + +// --------------------------------------------------------------------------- +// CORSMiddleware — runtime tests +// --------------------------------------------------------------------------- + +// echoHandler writes a 200 and records that it was called via the pointed bool. 
+func echoHandler(called *bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + *called = true + w.WriteHeader(http.StatusOK) + }) +} + +func enabledCORSConfig() CORSConfig { + return ParseCORSConfig( + discardLogger(), + "http://example.com, https://other.example.com", + []string{"POST", "OPTIONS"}, + []string{"Content-Type"}, + ) +} + +func TestCORSMiddleware_Disabled(t *testing.T) { + cfg := ParseCORSConfig(discardLogger(), "", []string{"POST"}, []string{"Content-Type"}) + var called bool + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.True(t, called) + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Empty(t, rr.Header().Get("Vary")) +} + +func TestCORSMiddleware_NoOriginHeader(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.True(t, called) + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Empty(t, rr.Header().Get("Vary")) +} + +func TestCORSMiddleware_AllowedOrigin(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.True(t, called) + require.Equal(t, "http://example.com", rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") + require.Equal(t, "POST, OPTIONS", rr.Header().Get("Access-Control-Allow-Methods")) + require.Equal(t, "Content-Type", rr.Header().Get("Access-Control-Allow-Headers")) +} + +func TestCORSMiddleware_DisallowedOrigin(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := 
httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://evil.com") + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.True(t, called) + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") +} + +func TestCORSMiddleware_CaseInsensitive(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "HTTP://EXAMPLE.COM") + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.True(t, called) + require.Equal(t, "HTTP://EXAMPLE.COM", rr.Header().Get("Access-Control-Allow-Origin")) +} + +func TestCORSMiddleware_PreflightAllowed(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "http://example.com") + req.Header.Set("Access-Control-Request-Method", "POST") + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.False(t, called) + require.Equal(t, http.StatusNoContent, rr.Code) + require.Equal(t, "http://example.com", rr.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "POST, OPTIONS", rr.Header().Get("Access-Control-Allow-Methods")) + require.Equal(t, "Content-Type", rr.Header().Get("Access-Control-Allow-Headers")) + require.Equal(t, "3600", rr.Header().Get("Access-Control-Max-Age")) + vary := rr.Header().Values("Vary") + require.Contains(t, vary, "Origin") + require.Contains(t, vary, "Access-Control-Request-Method") + require.Contains(t, vary, "Access-Control-Request-Headers") +} + +func TestCORSMiddleware_PreflightDisallowed(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "http://evil.com") + req.Header.Set("Access-Control-Request-Method", "POST") + rr := runMiddleware(t, CORSMiddleware(cfg), 
echoHandler(&called), req) + + require.False(t, called, "disallowed preflight must not reach downstream handler") + require.Equal(t, http.StatusNoContent, rr.Code) + require.Empty(t, rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") + require.Empty(t, rr.Header().Get("Access-Control-Max-Age")) +} + +func TestCORSMiddleware_OptionsWithoutRequestMethod(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := httptest.NewRequest(http.MethodOptions, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.True(t, called) + require.Equal(t, "http://example.com", rr.Header().Get("Access-Control-Allow-Origin")) + require.Empty(t, rr.Header().Get("Access-Control-Max-Age")) +} + +func TestCORSMiddleware_CredentialsNeverSet(t *testing.T) { + cfg := enabledCORSConfig() + var called bool + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), echoHandler(&called), req) + + require.True(t, called) + require.Empty(t, rr.Header().Get("Access-Control-Allow-Credentials")) +} + +func TestCORSMiddleware_UpstreamHeadersStripped(t *testing.T) { + cfg := enabledCORSConfig() + upstream := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "upstream") + w.WriteHeader(http.StatusOK) + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), upstream, req) + + require.Equal(t, "http://example.com", rr.Header().Get("Access-Control-Allow-Origin")) +} + +func TestCORSMiddleware_VaryUsesAdd(t *testing.T) { + cfg := enabledCORSConfig() + upstream := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Vary", "Accept-Encoding") + w.WriteHeader(http.StatusOK) + }) + 
req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), upstream, req) + + vary := rr.Header().Values("Vary") + require.Contains(t, vary, "Origin") + require.Contains(t, vary, "Accept-Encoding") +} + +func TestCORSMiddleware_ErrorResponseGetsCORSHeaders(t *testing.T) { + cfg := enabledCORSConfig() + errorHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), errorHandler, req) + + require.Equal(t, http.StatusInternalServerError, rr.Code) + require.Equal(t, "http://example.com", rr.Header().Get("Access-Control-Allow-Origin")) + require.Contains(t, rr.Header().Values("Vary"), "Origin") +} + +func TestCorsWriter_Unwrap(t *testing.T) { + inner := httptest.NewRecorder() + cw := &corsWriter{ResponseWriter: inner, origin: "http://example.com", cfg: enabledCORSConfig()} + + require.Same(t, http.ResponseWriter(inner), cw.Unwrap()) +} + +func TestCORSMiddleware_MaxBytesReaderUnwrapChain(t *testing.T) { + cfg := enabledCORSConfig() + + // Handler uses MaxBytesReader to enforce a 16-byte body limit. + // MaxBytesReader walks the Unwrap() chain to reach the real + // *http.response so it can force-close the connection after 413. + // Without Unwrap on corsWriter, the walk would stop at the + // embedded ResponseWriter interface and the close would silently + // fail. 
+ maxBodyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, 16) + _, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "request too large", http.StatusRequestEntityTooLarge) + return + } + w.WriteHeader(http.StatusOK) + }) + + oversizedBody := strings.Repeat("X", 1024) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(oversizedBody)) + req.Header.Set("Origin", "http://example.com") + rr := runMiddleware(t, CORSMiddleware(cfg), maxBodyHandler, req) + + require.Equal(t, http.StatusRequestEntityTooLarge, rr.Code) + require.Equal(t, "http://example.com", rr.Header().Get("Access-Control-Allow-Origin"), + "CORS headers must be present on 413 responses when corsWriter wraps the response writer") + require.Contains(t, rr.Header().Values("Vary"), "Origin") +} + +// TestCORSMiddleware_Admission503GetsCORSHeaders verifies that when admission +// control rejects a request with 503, the response still carries CORS headers. +// The middleware chain is CORS -> Admission -> handler. The corsWriter wraps +// the response writer before admission runs, so WriteHeader(503) triggers CORS +// header injection. A middleware reorder would silently break this invariant. +func TestCORSMiddleware_Admission503GetsCORSHeaders(t *testing.T) { + cfg := enabledCORSConfig() + + // Create an admission controller with limit 1 and immediately fill it + // so every subsequent request is rejected with 503. + admission := NewSemaphoreAdmission(1) + acquired := admission.TryAcquire() + require.True(t, acquired, "pre-fill must succeed on a fresh semaphore") + + // The inner handler must never be reached — admission rejects first. + var handlerCalled bool + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + handlerCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Build the middleware chain: CORS(Admission(handler)). 
+ chain := CORSMiddleware(cfg)(AdmissionMiddleware(admission)(handler)) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set("Origin", "http://example.com") + rr := httptest.NewRecorder() + chain.ServeHTTP(rr, req) + + require.False(t, handlerCalled, "handler must not be called when admission rejects") + require.Equal(t, http.StatusServiceUnavailable, rr.Code) + require.Equal(t, "http://example.com", rr.Header().Get("Access-Control-Allow-Origin"), + "CORS headers must be present on admission 503 responses") + require.Contains(t, rr.Header().Values("Vary"), "Origin") + require.Contains(t, rr.Body.String(), "service at capacity") + + retryAfter := rr.Header().Get("Retry-After") + require.NotEmpty(t, retryAfter, "Retry-After header must be present on 503") + retryVal, err := strconv.Atoi(retryAfter) + require.NoError(t, err, "Retry-After must be a valid integer") + require.GreaterOrEqual(t, retryVal, 1) + require.LessOrEqual(t, retryVal, 3) +} diff --git a/pkg/service/http_middleware.go b/pkg/service/http_middleware.go new file mode 100644 index 000000000..93a6ef6e2 --- /dev/null +++ b/pkg/service/http_middleware.go @@ -0,0 +1,198 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bufio" + "context" + "fmt" + "log/slog" + "net" + "net/http" + "regexp" + "runtime/debug" + + "github.com/google/uuid" +) + +const ( + requestIDHeader = "X-Request-ID" +) + +// requestIDPattern defines the accepted charset and length for +// upstream-supplied request ids. Anything outside this set is treated as +// untrusted and a fresh UUID is generated instead. 
+// +// The charset is deliberately chosen to cover the ID formats emitted by +// common reverse proxies and tracing systems while remaining safe to log +// and echo in response headers: +// +// - envoy uses `.` and `:` +// - AWS ALB and X-Ray use `-` and `=` +// - GCP Cloud Trace uses `/` +// - `+` appears in some base64-style correlation IDs +// +// Explicitly excluded: `\r`, `\n`, space, `<`, `>`, `"`, `'`, backtick, +// and all other control characters — these would enable log-injection or +// header-splitting if echoed back verbatim. The `{1,128}` quantifier is +// the single source of truth for the length bound. +var requestIDPattern = regexp.MustCompile(`^[A-Za-z0-9._:=/+-]{1,128}$`) + +// RequestIDMiddleware reads the X-Request-ID request header and trusts it +// only when it matches ^[A-Za-z0-9._:=/+-]{1,128}$. Otherwise a fresh +// UUIDv4 is generated. The chosen id is placed on the request context +// under ctxKeyRequestID{} and echoed on the response as X-Request-ID. +func RequestIDMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := r.Header.Get(requestIDHeader) + if !requestIDPattern.MatchString(id) { + id = uuid.NewString() + } + w.Header().Set(requestIDHeader, id) + ctx := context.WithValue(r.Context(), ctxKeyRequestID{}, id) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// responseWriterTap wraps an http.ResponseWriter and records whether any +// header or body bytes have been sent. It implements Unwrap so that helpers +// such as http.MaxBytesReader and http.ResponseController can walk the +// wrapper chain to reach the underlying *response and access its internal +// interfaces (for example, to force a connection close after 413). +// +// Without Unwrap, oversized POST bodies still return 413 but the underlying +// connection is not force-closed and clients may see protocol confusion on a +// subsequent pipelined request. Do not remove Unwrap. 
+type responseWriterTap struct { + http.ResponseWriter + wroteHeader bool +} + +func (t *responseWriterTap) WriteHeader(code int) { + t.wroteHeader = true + t.ResponseWriter.WriteHeader(code) +} + +func (t *responseWriterTap) Write(b []byte) (int, error) { + t.wroteHeader = true + return t.ResponseWriter.Write(b) +} + +func (t *responseWriterTap) Unwrap() http.ResponseWriter { + return t.ResponseWriter +} + +// Flush forwards to the wrapped writer's Flusher implementation if it +// supports streaming. Because responseWriterTap embeds the +// http.ResponseWriter interface (not a concrete type), Go's method +// promotion only surfaces Header/Write/WriteHeader — optional interfaces +// like http.Flusher are not promoted and a direct +// w.(http.Flusher).Flush() on the tap would fail the type assertion even +// when the underlying writer supports it. Forward explicitly so handlers +// can stream through RecoverMiddleware. +// +// http.NewResponseController(w).Flush() also works for the same purpose +// because it walks Unwrap() — use that path when you can't rely on a +// type assertion. +func (t *responseWriterTap) Flush() { + if f, ok := t.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// Hijack forwards to the wrapped writer's Hijacker implementation if any. +// Required for any future websocket or connection-takeover use case; +// today no inspect or jsonrpc handler hijacks, but having the forwarding +// in place avoids a silent failure the moment one is added. Returns +// http.ErrNotSupported when the underlying writer does not implement +// http.Hijacker (for example, httptest.ResponseRecorder). +func (t *responseWriterTap) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := t.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, http.ErrNotSupported +} + +// RecoverMiddleware catches panics from downstream handlers. If nothing has +// been written yet it emits a generic 500 via WriteInternalError. 
If bytes +// are already on the wire it re-panics with http.ErrAbortHandler to drop the +// connection silently — stitching a 500 onto a partial 200 would produce a +// corrupt response and a "superfluous WriteHeader" warning. +// +// Panics whose value is http.ErrAbortHandler are re-panicked unchanged so +// that the stdlib server preserves its special (non-logged) semantics. +// +// When RecoverMiddleware wraps RequestIDMiddleware (Recover outermost), the +// deferred recovery cannot read the request id from r.Context() because r is +// the pre-wrap request from this closure — downstream middleware creates a +// new *http.Request via WithContext and passes it to its next.ServeHTTP, but +// the outer r is untouched. The request id is instead read from the tap's +// response header, which RequestIDMiddleware populates on the shared +// ResponseWriter before calling next, so it is always set by the time a +// downstream panic reaches the defer. +func RecoverMiddleware(logger *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tap := &responseWriterTap{ResponseWriter: w} + defer func() { + rec := recover() + if rec == nil { + return + } + if rec == http.ErrAbortHandler { + panic(rec) + } + reqID := tap.Header().Get(requestIDHeader) + logger.Error("http handler panic", + "panic", rec, + "stack", string(debug.Stack()), + "request_id", reqID, + ) + if !tap.wroteHeader { + ctx := context.WithValue(r.Context(), ctxKeyRequestID{}, reqID) + WriteInternalError(ctx, tap, logger, fmt.Errorf("panic in handler: %v", rec)) + return + } + panic(http.ErrAbortHandler) + }() + next.ServeHTTP(tap, r) + }) + } +} + +// HandlerOptions bundles the middleware dependencies for +// [NewServiceHandler]. Logger must be non-nil. Admission may be nil to +// disable admission control, and CORS may be the zero value +// ([CORSConfig]{}) to disable CORS. 
+type HandlerOptions struct { + Logger *slog.Logger + Admission *SemaphoreAdmission + CORS CORSConfig +} + +// NewServiceHandler wraps h with the node's canonical HTTP hardening +// chain, in order: +// +// RecoverMiddleware -> RequestIDMiddleware -> CORSMiddleware -> AdmissionMiddleware -> h +// +// RecoverMiddleware is outermost so it also catches panics raised inside +// RequestIDMiddleware itself (e.g. an entropy-source failure inside +// uuid.NewString). Without this ordering such a panic would escape to +// net/http's default goroutine recover, dropping the connection with no +// structured log and no 500 response. +// +// CORSMiddleware sits outside AdmissionMiddleware so preflight OPTIONS +// requests and requests from disallowed origins never consume an +// admission permit. +// +// Centralizing the chain here makes the order a property of the helper, +// not of every call site, so every HTTP surface in the node has the same +// posture. +func NewServiceHandler(h http.Handler, opts HandlerOptions) http.Handler { + h = AdmissionMiddleware(opts.Admission)(h) + h = CORSMiddleware(opts.CORS)(h) + h = RequestIDMiddleware(h) + h = RecoverMiddleware(opts.Logger)(h) + return h +} diff --git a/pkg/service/http_middleware_test.go b/pkg/service/http_middleware_test.go new file mode 100644 index 000000000..03cc107b2 --- /dev/null +++ b/pkg/service/http_middleware_test.go @@ -0,0 +1,481 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package service + +import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// runMiddleware wires a middleware in front of a handler and records the +// response via httptest.NewRecorder. 
+func runMiddleware(t *testing.T, mw func(http.Handler) http.Handler, h http.Handler, req *http.Request) *httptest.ResponseRecorder { + t.Helper() + rr := httptest.NewRecorder() + mw(h).ServeHTTP(rr, req) + return rr +} + +// ----------------------------------------------------------------------------- +// RequestIDMiddleware +// ----------------------------------------------------------------------------- + +func TestRequestIDMiddleware_GeneratesWhenMissing(t *testing.T) { + var capturedID string + h := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedID = RequestIDFromContext(r.Context()) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := runMiddleware(t, RequestIDMiddleware, h, req) + + require.NotEmpty(t, capturedID) + require.Equal(t, capturedID, rr.Header().Get("X-Request-ID")) + + _, err := uuid.Parse(capturedID) + require.NoError(t, err, "generated id should parse as UUID") +} + +func TestRequestIDMiddleware_AcceptsValid(t *testing.T) { + // Pin the full accepted charset. 
Each entry represents a real-world + // upstream format we must preserve end-to-end for correlation: + // - "abc_123-xyz" — legacy underscore/hyphen + // - "abc.def.123" — envoy-style dotted id + // - "a:b:c" — envoy host:port:id + // - "trace=1-2-3" — AWS X-Ray style with '=' + // - "projects/foo/traces/bar" — GCP Cloud Trace path + // - "trace+span" — base64-ish '+' + cases := []string{ + "abc_123-xyz", + "abc.def.123", + "a:b:c", + "trace=1-2-3", + "projects/foo/traces/bar", + "trace+span", + } + for _, valid := range cases { + t.Run(valid, func(t *testing.T) { + var capturedID string + h := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedID = RequestIDFromContext(r.Context()) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-ID", valid) + rr := runMiddleware(t, RequestIDMiddleware, h, req) + + require.Equal(t, valid, capturedID) + require.Equal(t, valid, rr.Header().Get("X-Request-ID")) + }) + } +} + +func TestRequestIDMiddleware_RejectsTooLong(t *testing.T) { + long := strings.Repeat("a", 129) + var capturedID string + h := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + capturedID = RequestIDFromContext(r.Context()) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-ID", long) + runMiddleware(t, RequestIDMiddleware, h, req) + + require.NotEqual(t, long, capturedID) + _, err := uuid.Parse(capturedID) + require.NoError(t, err, "regenerated id should parse as UUID") +} + +func TestRequestIDMiddleware_RejectsBadChars(t *testing.T) { + // Each case must be regenerated as a fresh UUID. Keep the charset + // exclusion list tight: anything that could enable log-injection, + // header-splitting, or HTML/JS smuggling when echoed back in logs or + // on the X-Request-ID response header. 
+ cases := map[string]string{ + "semicolon": "foo;bar", + "space": "id with space", + "newline": "id\nnewline", + "carriage": "id\rcr", + "tab": "foo\tbar", + "angle": "