Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 284 additions & 4 deletions crates/openshell-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,169 @@ fn configured_compute_driver(config: &Config) -> Result<ComputeDriverKind> {
#[cfg(test)]
mod tests {
use super::{
ConnectionProtocol, allow_plaintext_service_http, classify_initial_bytes,
configured_compute_driver, gateway_listener_addresses, is_benign_tls_handshake_failure,
ConnectionProtocol, MultiplexService, ServerState, TlsAcceptor,
allow_plaintext_service_http, classify_initial_bytes, configured_compute_driver,
gateway_listener_addresses, is_benign_tls_handshake_failure, serve_gateway_listener,
};
use openshell_core::{ComputeDriverKind, Config};
use std::io::{Error, ErrorKind};
use openshell_core::{
ComputeDriverKind, Config,
proto::{HealthRequest, open_shell_client::OpenShellClient},
};
use rcgen::{CertificateParams, IsCa, KeyPair};
use std::io::{Error, ErrorKind, Write};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tempfile::{TempDir, tempdir};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch;

fn install_rustls_provider() {
let _ = rustls::crypto::ring::default_provider().install_default();
}

fn test_tls_acceptor() -> (TempDir, TlsAcceptor) {
install_rustls_provider();

let mut ca_params =
CertificateParams::new(Vec::<String>::new()).expect("failed to create CA params");
ca_params.is_ca = IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
ca_params
.distinguished_name
.push(rcgen::DnType::CommonName, "test-ca");
let ca_key = KeyPair::generate().expect("failed to generate CA key");
let ca_cert = ca_params
.self_signed(&ca_key)
.expect("failed to sign CA cert");

let server_params = CertificateParams::new(vec!["localhost".to_string()])
.expect("failed to create server params");
let server_key = KeyPair::generate().expect("failed to generate server key");
let server_cert = server_params
.signed_by(&server_key, &ca_cert, &ca_key)
.expect("failed to sign server cert");

let dir = tempdir().expect("failed to create tempdir");
let write_file = |name: &str, data: &[u8]| {
let path = dir.path().join(name);
std::fs::File::create(&path)
.and_then(|mut file| file.write_all(data))
.expect("failed to write tls test file");
};
write_file("ca.pem", ca_cert.pem().as_bytes());
write_file("server-cert.pem", server_cert.pem().as_bytes());
write_file("server-key.pem", server_key.serialize_pem().as_bytes());

let acceptor = TlsAcceptor::from_files(
&dir.path().join("server-cert.pem"),
&dir.path().join("server-key.pem"),
&dir.path().join("ca.pem"),
false,
)
.expect("failed to build tls acceptor");

(dir, acceptor)
}

async fn test_state(
bind_addr: SocketAddr,
enable_loopback_service_http: bool,
) -> Arc<ServerState> {
let store = Arc::new(
crate::persistence::Store::connect("sqlite::memory:?cache=shared")
.await
.expect("failed to create test store"),
);
let compute = crate::compute::new_test_runtime(store.clone()).await;
Arc::new(ServerState::new(
Config::new(None)
.with_database_url("sqlite::memory:?cache=shared")
.with_bind_address(bind_addr)
.with_server_sans(["*.dev.openshell.localhost"])
.with_loopback_service_http(enable_loopback_service_http),
store,
compute,
crate::sandbox_index::SandboxIndex::new(),
crate::sandbox_watch::SandboxWatchBus::new(),
crate::tracing_bus::TracingLogBus::new(),
Arc::new(crate::supervisor_session::SupervisorSessionRegistry::new()),
None,
))
}

async fn start_tls_gateway_listener(
bind_addr: &str,
enable_loopback_service_http: bool,
) -> (
SocketAddr,
watch::Sender<bool>,
tokio::task::JoinHandle<()>,
TempDir,
) {
let listener = TcpListener::bind(bind_addr)
.await
.expect("failed to bind test listener");
let listen_addr = listener.local_addr().expect("failed to read local addr");
let state = test_state(listen_addr, enable_loopback_service_http).await;
let service = MultiplexService::new(state);
let (tls_dir, tls_acceptor) = test_tls_acceptor();
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let handle = tokio::spawn(serve_gateway_listener(
listener,
listen_addr,
service,
Some(tls_acceptor),
enable_loopback_service_http,
shutdown_rx,
));
(listen_addr, shutdown_tx, handle, tls_dir)
}

async fn send_plain_http(addr: SocketAddr, request: String) -> String {
let connect_addr: SocketAddr = format!("127.0.0.1:{}", addr.port())
.parse()
.expect("failed to build loopback connect addr");
let mut stream = TcpStream::connect(connect_addr)
.await
.expect("failed to connect to test listener");
stream
.write_all(request.as_bytes())
.await
.expect("failed to write request");

let mut response = Vec::new();
let read_result =
tokio::time::timeout(Duration::from_secs(2), stream.read_to_end(&mut response))
.await
.expect("timed out reading response");
if let Err(err) = read_result
&& err.kind() != ErrorKind::ConnectionReset
{
panic!("failed to read response: {err}");
}
String::from_utf8_lossy(&response).into_owned()
}

fn service_request(addr: SocketAddr, extra_headers: &[(&str, &str)]) -> String {
let mut request = format!(
"GET / HTTP/1.1\r\nHost: my-sandbox--web.dev.openshell.localhost:{}\r\nConnection: close\r\n",
addr.port()
);
for (name, value) in extra_headers {
request.push_str(name);
request.push_str(": ");
request.push_str(value);
request.push_str("\r\n");
}
request.push_str("\r\n");
request
}

async fn stop_listener(shutdown: watch::Sender<bool>, handle: tokio::task::JoinHandle<()>) {
let _ = shutdown.send(true);
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
}

#[test]
fn classifies_probe_style_tls_disconnects_as_benign() {
Expand Down Expand Up @@ -782,6 +939,129 @@ mod tests {
assert!(!allow_plaintext_service_http(true, loopback, remote_peer));
}

#[tokio::test]
async fn plaintext_service_http_listener_rejects_non_loopback_bind() {
let (addr, shutdown, handle, _tls_dir) =
start_tls_gateway_listener("0.0.0.0:0", true).await;

let response = send_plain_http(addr, service_request(addr, &[])).await;

assert!(
response.is_empty(),
"non-loopback gateway listener should drop plaintext service HTTP, got: {response:?}"
);
stop_listener(shutdown, handle).await;
}

#[tokio::test]
async fn plaintext_service_http_rejects_cross_origin_browser_contexts() {
let (addr, shutdown, handle, _tls_dir) =
start_tls_gateway_listener("127.0.0.1:0", true).await;
let cases = [
(
"cross-site fetch metadata",
vec![("Sec-Fetch-Site", "cross-site")],
),
(
"same-site sibling fetch metadata",
vec![("Sec-Fetch-Site", "same-site")],
),
(
"mismatched origin",
vec![(
"Origin",
"http://other-sandbox--web.dev.openshell.localhost:8080",
)],
),
(
"mismatched referer",
vec![(
"Referer",
"http://other-sandbox--web.dev.openshell.localhost:8080/page",
)],
),
];

for (name, headers) in cases {
let response = send_plain_http(addr, service_request(addr, &headers)).await;

assert!(
response.starts_with("HTTP/1.1 403 Forbidden"),
"{name} should be rejected before service lookup, got: {response:?}"
);
assert!(
response.contains("Cross-origin service request rejected"),
"{name} should explain the service rejection, got: {response:?}"
);
}
stop_listener(shutdown, handle).await;
}

#[tokio::test]
async fn plaintext_service_http_allows_same_origin_browser_context_to_reach_service_lookup() {
let (addr, shutdown, handle, _tls_dir) =
start_tls_gateway_listener("127.0.0.1:0", true).await;
let origin = format!(
"http://my-sandbox--web.dev.openshell.localhost:{}",
addr.port()
);
let response = send_plain_http(
addr,
service_request(
addr,
&[("Sec-Fetch-Site", "same-origin"), ("Origin", &origin)],
),
)
.await;

assert!(
response.starts_with("HTTP/1.1 404 Not Found"),
"same-origin browser context should pass CSRF guard and miss only because no endpoint exists, got: {response:?}"
);
assert!(
!response.contains("Cross-origin service request rejected"),
"same-origin browser context should not be rejected as cross-origin, got: {response:?}"
);
stop_listener(shutdown, handle).await;
}

#[tokio::test]
async fn plaintext_service_http_does_not_expose_grpc_gateway() {
let (addr, shutdown, handle, _tls_dir) =
start_tls_gateway_listener("127.0.0.1:0", true).await;
let grpc_endpoint = format!("http://127.0.0.1:{}", addr.port());
let grpc_succeeded = tokio::time::timeout(Duration::from_secs(2), async {
match OpenShellClient::connect(grpc_endpoint).await {
Ok(mut client) => client.health(HealthRequest {}).await.is_ok(),
Err(_) => false,
}
})
.await
.expect("timed out checking plaintext gRPC exposure");

assert!(
!grpc_succeeded,
"plaintext service HTTP must not expose successful gateway gRPC"
);

let request = format!(
"POST /openshell.v1.OpenShell/Health HTTP/1.1\r\nHost: 127.0.0.1:{}\r\nContent-Type: application/grpc\r\nTE: trailers\r\nContent-Length: 0\r\nConnection: close\r\n\r\n",
addr.port()
);

let response = send_plain_http(addr, request).await;

assert!(
response.starts_with("HTTP/1.1 404 Not Found"),
"plaintext service HTTP router should not serve gateway gRPC, got: {response:?}"
);
assert!(
!response.contains("grpc-status: 0"),
"plaintext service HTTP must not return a successful gRPC response: {response:?}"
);
stop_listener(shutdown, handle).await;
}

#[test]
fn configured_compute_driver_triggers_auto_detection_when_empty() {
let config = Config::new(None).with_compute_drivers([]);
Expand Down
Loading