From 20213d2373d37d5606edcafe73a88a694a2ba105 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Tue, 12 May 2026 18:56:27 -0700 Subject: [PATCH] test(server): cover service endpoint plaintext security --- crates/openshell-server/src/lib.rs | 288 ++++++++++++++++++++++++++++- 1 file changed, 284 insertions(+), 4 deletions(-) diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 93ccdc9dc..3012e42d2 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -725,12 +725,169 @@ fn configured_compute_driver(config: &Config) -> Result { #[cfg(test)] mod tests { use super::{ - ConnectionProtocol, allow_plaintext_service_http, classify_initial_bytes, - configured_compute_driver, gateway_listener_addresses, is_benign_tls_handshake_failure, + ConnectionProtocol, MultiplexService, ServerState, TlsAcceptor, + allow_plaintext_service_http, classify_initial_bytes, configured_compute_driver, + gateway_listener_addresses, is_benign_tls_handshake_failure, serve_gateway_listener, }; - use openshell_core::{ComputeDriverKind, Config}; - use std::io::{Error, ErrorKind}; + use openshell_core::{ + ComputeDriverKind, Config, + proto::{HealthRequest, open_shell_client::OpenShellClient}, + }; + use rcgen::{CertificateParams, IsCa, KeyPair}; + use std::io::{Error, ErrorKind, Write}; use std::net::SocketAddr; + use std::sync::Arc; + use std::time::Duration; + use tempfile::{TempDir, tempdir}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpListener, TcpStream}; + use tokio::sync::watch; + + fn install_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); + } + + fn test_tls_acceptor() -> (TempDir, TlsAcceptor) { + install_rustls_provider(); + + let mut ca_params = + CertificateParams::new(Vec::::new()).expect("failed to create CA params"); + ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-ca"); + let ca_key = KeyPair::generate().expect("failed to generate CA key"); + let ca_cert = ca_params + .self_signed(&ca_key) + .expect("failed to sign CA cert"); + + let server_params = CertificateParams::new(vec!["localhost".to_string()]) + .expect("failed to create server params"); + let server_key = KeyPair::generate().expect("failed to generate server key"); + let server_cert = server_params + .signed_by(&server_key, &ca_cert, &ca_key) + .expect("failed to sign server cert"); + + let dir = tempdir().expect("failed to create tempdir"); + let write_file = |name: &str, data: &[u8]| { + let path = dir.path().join(name); + std::fs::File::create(&path) + .and_then(|mut file| file.write_all(data)) + .expect("failed to write tls test file"); + }; + write_file("ca.pem", ca_cert.pem().as_bytes()); + write_file("server-cert.pem", server_cert.pem().as_bytes()); + write_file("server-key.pem", server_key.serialize_pem().as_bytes()); + + let acceptor = TlsAcceptor::from_files( + &dir.path().join("server-cert.pem"), + &dir.path().join("server-key.pem"), + &dir.path().join("ca.pem"), + false, + ) + .expect("failed to build tls acceptor"); + + (dir, acceptor) + } + + async fn test_state( + bind_addr: SocketAddr, + enable_loopback_service_http: bool, + ) -> Arc { + let store = Arc::new( + crate::persistence::Store::connect("sqlite::memory:?cache=shared") + .await + .expect("failed to create test store"), + ); + let compute = crate::compute::new_test_runtime(store.clone()).await; + Arc::new(ServerState::new( + Config::new(None) + .with_database_url("sqlite::memory:?cache=shared") + .with_bind_address(bind_addr) + .with_server_sans(["*.dev.openshell.localhost"]) + .with_loopback_service_http(enable_loopback_service_http), + store, + compute, + crate::sandbox_index::SandboxIndex::new(), + crate::sandbox_watch::SandboxWatchBus::new(), + crate::tracing_bus::TracingLogBus::new(), + Arc::new(crate::supervisor_session::SupervisorSessionRegistry::new()), + None, + )) + } + + async fn start_tls_gateway_listener( + bind_addr: &str, + enable_loopback_service_http: bool, + ) -> ( + SocketAddr, + watch::Sender, + tokio::task::JoinHandle<()>, + TempDir, + ) { + let listener = TcpListener::bind(bind_addr) + .await + .expect("failed to bind test listener"); + let listen_addr = listener.local_addr().expect("failed to read local addr"); + let state = test_state(listen_addr, enable_loopback_service_http).await; + let service = MultiplexService::new(state); + let (tls_dir, tls_acceptor) = test_tls_acceptor(); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let handle = tokio::spawn(serve_gateway_listener( + listener, + listen_addr, + service, + Some(tls_acceptor), + enable_loopback_service_http, + shutdown_rx, + )); + (listen_addr, shutdown_tx, handle, tls_dir) + } + + async fn send_plain_http(addr: SocketAddr, request: String) -> String { + let connect_addr: SocketAddr = format!("127.0.0.1:{}", addr.port()) + .parse() + .expect("failed to build loopback connect addr"); + let mut stream = TcpStream::connect(connect_addr) + .await + .expect("failed to connect to test listener"); + stream + .write_all(request.as_bytes()) + .await + .expect("failed to write request"); + + let mut response = Vec::new(); + let read_result = + tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut response)) + .await + .expect("timed out reading response"); + if let Err(err) = read_result + && err.kind() != ErrorKind::ConnectionReset + { + panic!("failed to read response: {err}"); + } + String::from_utf8_lossy(&response).into_owned() + } + + fn service_request(addr: SocketAddr, extra_headers: &[(&str, &str)]) -> String { + let mut request = format!( + "GET / HTTP/1.1\r\nHost: my-sandbox--web.dev.openshell.localhost:{}\r\nConnection: close\r\n", + addr.port() + ); + for (name, value) in extra_headers { + request.push_str(name); + request.push_str(": "); + request.push_str(value); + request.push_str("\r\n"); + } + request.push_str("\r\n"); + request + } + + async fn stop_listener(shutdown: watch::Sender, handle: tokio::task::JoinHandle<()>) { + let _ = shutdown.send(true); + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + } #[test] fn classifies_probe_style_tls_disconnects_as_benign() { @@ -782,6 +939,129 @@ mod tests { assert!(!allow_plaintext_service_http(true, loopback, remote_peer)); } + #[tokio::test] + async fn plaintext_service_http_listener_rejects_non_loopback_bind() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("0.0.0.0:0", true).await; + + let response = send_plain_http(addr, service_request(addr, &[])).await; + + assert!( + response.is_empty(), + "non-loopback gateway listener should drop plaintext service HTTP, got: {response:?}" + ); + stop_listener(shutdown, handle).await; + } + + #[tokio::test] + async fn plaintext_service_http_rejects_cross_origin_browser_contexts() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("127.0.0.1:0", true).await; + let cases = [ + ( + "cross-site fetch metadata", + vec![("Sec-Fetch-Site", "cross-site")], + ), + ( + "same-site sibling fetch metadata", + vec![("Sec-Fetch-Site", "same-site")], + ), + ( + "mismatched origin", + vec![( + "Origin", + "http://other-sandbox--web.dev.openshell.localhost:8080", + )], + ), + ( + "mismatched referer", + vec![( + "Referer", + "http://other-sandbox--web.dev.openshell.localhost:8080/page", + )], + ), + ]; + + for (name, headers) in cases { + let response = send_plain_http(addr, service_request(addr, &headers)).await; + + assert!( + response.starts_with("HTTP/1.1 403 Forbidden"), + "{name} should be rejected before service lookup, got: {response:?}" + ); + assert!( + response.contains("Cross-origin service request rejected"), + "{name} should explain the service rejection, got: {response:?}" + ); + } + stop_listener(shutdown, handle).await; + } + + #[tokio::test] + async fn plaintext_service_http_allows_same_origin_browser_context_to_reach_service_lookup() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("127.0.0.1:0", true).await; + let origin = format!( + "http://my-sandbox--web.dev.openshell.localhost:{}", + addr.port() + ); + let response = send_plain_http( + addr, + service_request( + addr, + &[("Sec-Fetch-Site", "same-origin"), ("Origin", &origin)], + ), + ) + .await; + + assert!( + response.starts_with("HTTP/1.1 404 Not Found"), + "same-origin browser context should pass CSRF guard and miss only because no endpoint exists, got: {response:?}" + ); + assert!( + !response.contains("Cross-origin service request rejected"), + "same-origin browser context should not be rejected as cross-origin, got: {response:?}" + ); + stop_listener(shutdown, handle).await; + } + + #[tokio::test] + async fn plaintext_service_http_does_not_expose_grpc_gateway() { + let (addr, shutdown, handle, _tls_dir) = + start_tls_gateway_listener("127.0.0.1:0", true).await; + let grpc_endpoint = format!("http://127.0.0.1:{}", addr.port()); + let grpc_succeeded = tokio::time::timeout(Duration::from_secs(2), async { + match OpenShellClient::connect(grpc_endpoint).await { + Ok(mut client) => client.health(HealthRequest {}).await.is_ok(), + Err(_) => false, + } + }) + .await + .expect("timed out checking plaintext gRPC exposure"); + + assert!( + !grpc_succeeded, + "plaintext service HTTP must not expose successful gateway gRPC" + ); + + let request = format!( + "POST /openshell.v1.OpenShell/Health HTTP/1.1\r\nHost: 127.0.0.1:{}\r\nContent-Type: application/grpc\r\nTE: trailers\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", + addr.port() + ); + + let response = send_plain_http(addr, request).await; + + assert!( + response.starts_with("HTTP/1.1 404 Not Found"), + "plaintext service HTTP router should not serve gateway gRPC, got: {response:?}" + ); + assert!( + !response.contains("grpc-status: 0"), + "plaintext service HTTP must not return a successful gRPC response: {response:?}" + ); + stop_listener(shutdown, handle).await; + } + #[test] fn configured_compute_driver_triggers_auto_detection_when_empty() { let config = Config::new(None).with_compute_drivers([]);