diff --git a/Cargo.lock b/Cargo.lock index cb68ba277..960a60ef6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -471,6 +471,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstyle" version = "1.0.13" @@ -498,6 +504,7 @@ dependencies = [ "bytes", "chrono", "config", + "criterion", "database", "deadpool-postgres", "dotenvy", @@ -1627,6 +1634,12 @@ dependencies = [ "serde", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.2.51" @@ -1682,6 +1695,33 @@ dependencies = [ "windows-link", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "cipher" version = "0.4.4" @@ -1692,6 +1732,31 @@ dependencies = [ "inout", ] +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + [[package]] name = "cmake" version = "0.1.57" @@ -1867,6 +1932,44 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "futures", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "tokio", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "critical-section" version = "1.2.0" @@ -1882,6 +1985,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-epoch" version = "0.9.18" @@ -2935,6 +3048,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -3559,6 +3683,17 @@ dependencies = [ "serde", ] +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "itertools" version = "0.10.5" @@ -4229,6 +4364,12 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -4604,6 +4745,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polyval" version = "0.6.2" @@ -5056,6 +5225,26 @@ dependencies = [ "rustversion", ] +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -5964,6 +6153,7 @@ dependencies = [ "bytes", "chrono", "config", + "criterion", "dstack-sdk", "dstack-sdk-types", "ed25519-dalek", @@ -6475,6 +6665,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.10.0" diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 9f8f6a447..5aa0fa2e6 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -69,3 +69,8 @@ ed25519-dalek = { version = "2.1", features = ["rand_core"] } rand = "0.10" services = { path = "../services", features = ["test-mocks"] } async-trait = "0.1" +criterion = { version = "0.5", features = ["async_tokio"] } + +[[bench]] +name = "validation_bench" +harness = false diff --git a/crates/api/benches/validation_bench.rs b/crates/api/benches/validation_bench.rs new file mode 100644 index 000000000..37905448f --- /dev/null +++ b/crates/api/benches/validation_bench.rs @@ -0,0 +1,221 @@ +//! Benchmarks for request validation and body-hashing hot paths. +//! +//! Covers ChatCompletionRequest::validate() with varying message counts, +//! multimodal content validation, has_image_content() scanning, +//! serde_json::to_value() for content types, body SHA-256 hashing at +//! various sizes, and Bytes::clone() overhead. + +use std::collections::HashMap; + +use bytes::Bytes; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use sha2::{Digest, Sha256}; + +use api::models::{ + ChatCompletionRequest, Message, MessageContent, MessageContentPart, MessageImageUrl, +}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn make_text_message(role: &str, text: &str) -> Message { + Message { + role: role.to_string(), + content: Some(MessageContent::Text(text.to_string())), + name: None, + } +} + +fn make_multimodal_message() -> Message { + Message { + role: "user".to_string(), + content: Some(MessageContent::Parts(vec![ + MessageContentPart::Text { + text: "Describe this image".to_string(), + }, + MessageContentPart::ImageUrl { + image_url: MessageImageUrl::String("https://example.com/image.png".to_string()), + detail: None, + }, + ])), + name: None, + } +} + +fn make_request(message_count: usize) -> ChatCompletionRequest { + let mut messages = vec![make_text_message("system", "You are a helpful assistant.")]; + for i in 0..message_count { + if i % 2 == 0 { + messages.push(make_text_message("user", "Hello, how are you?")); + } else { + messages.push(make_text_message( + "assistant", + "I'm doing well, thanks for asking!", + )); + } + } + ChatCompletionRequest { + model: "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(), + messages, + max_tokens: Some(1024), + temperature: Some(1.0), + top_p: Some(1.0), + n: Some(1), + stream: Some(false), + stop: None, + presence_penalty: None, + frequency_penalty: None, + extra: HashMap::new(), + } +} + +fn make_multimodal_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(), + messages: vec![ + make_text_message("system", "You are a helpful assistant."), + make_multimodal_message(), + ], + max_tokens: Some(1024), + temperature: Some(1.0), + top_p: Some(1.0), + n: Some(1), + stream: Some(false), + stop: None, + presence_penalty: None, + frequency_penalty: None, + extra: HashMap::new(), + } +} + +// --------------------------------------------------------------------------- +// Benchmark group: validate +// --------------------------------------------------------------------------- + +fn bench_validate(c: &mut Criterion) { + let mut group = c.benchmark_group("validate_request"); + + for n in [10, 50] { + let req = make_request(n); + group.bench_with_input(BenchmarkId::new("text_messages", n), &req, |b, req| { + b.iter(|| black_box(black_box(req).validate())) + }); + } + + let multimodal = make_multimodal_request(); + group.bench_function("multimodal_message", |b| { + b.iter(|| black_box(black_box(&multimodal).validate())) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: serde_to_value +// --------------------------------------------------------------------------- + +fn bench_serde_to_value(c: &mut Criterion) { + let mut group = c.benchmark_group("serde_to_value"); + + let text_content = MessageContent::Text("Hello, how are you?".to_string()); + let multimodal_content = MessageContent::Parts(vec![ + MessageContentPart::Text { + text: "Describe this image".to_string(), + }, + MessageContentPart::ImageUrl { + image_url: MessageImageUrl::String("https://example.com/image.png".to_string()), + detail: None, + }, + ]); + + group.bench_function("text_content", |b| { + b.iter(|| black_box(serde_json::to_value(black_box(&text_content)).unwrap())) + }); + + group.bench_function("multimodal_content", |b| { + b.iter(|| black_box(serde_json::to_value(black_box(&multimodal_content)).unwrap())) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: has_image_content +// --------------------------------------------------------------------------- + +fn bench_has_image_content(c: &mut Criterion) { + let mut group = c.benchmark_group("has_image_content"); + + let text_only = make_request(50); + let with_images = make_multimodal_request(); + + group.bench_function("text_only_50_messages", |b| { + b.iter(|| black_box(black_box(&text_only).has_image_content())) + }); + + group.bench_function("with_images", |b| { + b.iter(|| black_box(black_box(&with_images).has_image_content())) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: body_sha256 +// --------------------------------------------------------------------------- + +fn bench_body_sha256(c: &mut Criterion) { + let mut group = c.benchmark_group("body_sha256"); + + for size_kb in [1, 10, 100] { + let body = vec![b'x'; size_kb * 1024]; + + group.throughput(Throughput::Bytes(body.len() as u64)); + + group.bench_with_input( + BenchmarkId::new("hash", format!("{}kb", size_kb)), + &body, + |b, body| { + b.iter(|| { + let mut hasher = Sha256::new(); + hasher.update(black_box(body)); + let hash = hasher.finalize(); + black_box(hex::encode(hash)) + }) + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: bytes_clone +// --------------------------------------------------------------------------- + +fn bench_bytes_clone(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_clone"); + + let data = Bytes::from(vec![b'x'; 10 * 1024]); + + group.bench_function("clone_10kb", |b| { + b.iter(|| black_box(black_box(&data).clone())) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +criterion_group!( + benches, + bench_validate, + bench_serde_to_value, + bench_has_image_content, + bench_body_sha256, + bench_bytes_clone, +); +criterion_main!(benches); diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index 4dbadac6d..18177d385 100644 --- a/crates/api/src/middleware/auth.rs +++ b/crates/api/src/middleware/auth.rs @@ -5,8 +5,11 @@ use axum::{ response::Response, }; use database::User as DbUser; +use moka::future::Cache; use services::auth::{AuthError, AuthServiceTrait, OAuthManager, SessionToken}; +use services::workspace::WorkspaceId; use std::sync::Arc; +use std::time::Duration; use tracing::{debug, error}; /// Authenticated user information passed to route handlers @@ -596,10 +599,20 @@ async fn authenticate_api_key_with_context( // Clone workspace_id to avoid partial move let workspace_id = validated_api_key.workspace_id.clone(); - // Get workspace with organization info + // Try the cache first, fall through to DB on miss + if let Some((workspace, organization)) = state.workspace_context_cache.get(&workspace_id).await + { + return Ok(AuthenticatedApiKey { + api_key: validated_api_key, + workspace, + organization, + }); + } + + // Get workspace with organization info from DB match state .workspace_repository - .get_workspace_with_organization(workspace_id) + .get_workspace_with_organization(workspace_id.clone()) .await { Ok(Some((workspace, organization))) => { @@ -607,6 +620,11 @@ async fn authenticate_api_key_with_context( "Resolved workspace: {} and organization: {} for API key", workspace.name, organization.name ); + // Populate cache + state + .workspace_context_cache + .insert(workspace_id, (workspace.clone(), organization.clone())) + .await; Ok(AuthenticatedApiKey { api_key: validated_api_key, workspace, @@ -645,6 +663,14 @@ pub struct AuthState { pub admin_access_token_repository: Arc, pub admin_domains: Vec, pub encoding_key: String, + /// Cache for workspace + organization lookups (avoids DB hit on every request) + workspace_context_cache: Cache< + WorkspaceId, + ( + services::workspace::Workspace, + services::organization::Organization, + ), + >, } impl AuthState { @@ -663,6 +689,10 @@ impl AuthState { admin_access_token_repository, admin_domains, encoding_key, + workspace_context_cache: Cache::builder() + .max_capacity(10_000) + .time_to_live(Duration::from_secs(30)) + .build(), } } } diff --git a/crates/api/src/models.rs b/crates/api/src/models.rs index 835251712..9a20bdda3 100644 --- a/crates/api/src/models.rs +++ b/crates/api/src/models.rs @@ -776,7 +776,7 @@ impl ChatCompletionRequest { return Err("messages cannot be empty".to_string()); } - for (idx, message) in self.messages.iter().enumerate() { + for message in &self.messages { if message.role.is_empty() { return Err("message role is required".to_string()); } @@ -785,15 +785,6 @@ impl ChatCompletionRequest { { return Err(format!("invalid message role: {}", message.role)); } - // Validate message content can be serialized (catches malformed multimodal content) - if let Some(ref content) = message.content { - if serde_json::to_value(content).is_err() { - return Err(format!( - "message at index {} has invalid content that cannot be processed", - idx - )); - } - } } if let Some(temp) = self.temperature { diff --git a/crates/api/src/routes/completions.rs b/crates/api/src/routes/completions.rs index 2bf019395..0879ab25b 100644 --- a/crates/api/src/routes/completions.rs +++ b/crates/api/src/routes/completions.rs @@ -386,9 +386,10 @@ pub async fn chat_completions( ); } - // Accumulate all SSE bytes for response hash computation - let accumulated_bytes = Arc::new(tokio::sync::Mutex::new(Vec::new())); - let chat_id_state = Arc::new(tokio::sync::Mutex::new(None::)); + // Accumulate all SSE bytes for response hash computation. + // Use std::sync::Mutex (no await points inside critical sections). + let accumulated_bytes = Arc::new(std::sync::Mutex::new(Vec::new())); + let chat_id_state = Arc::new(std::sync::Mutex::new(None::)); let stream_error_count = Arc::new(std::sync::atomic::AtomicU32::new(0)); let accumulated_clone = accumulated_bytes.clone(); @@ -398,65 +399,60 @@ pub async fn chat_completions( // Convert to raw bytes stream with proper SSE formatting let byte_stream = peekable_stream - .then(move |result| { - let accumulated_inner = accumulated_clone.clone(); - let chat_id_inner = chat_id_clone.clone(); - let error_count_inner = error_count_clone.clone(); - let model_for_err = request_model.clone(); - async move { - match result { - Ok(event) => { - // Extract chat_id from the first chunk if available - if let Ok(chunk_str) = - String::from_utf8(event.raw_bytes.to_vec()) - { - if let Some(data) = chunk_str.strip_prefix("data: ") { - if let Ok(serde_json::Value::Object(obj)) = - serde_json::from_str::( - data.trim(), - ) - { - if let Some(serde_json::Value::String(id)) = - obj.get("id") + .map(move |result| { + match result { + Ok(event) => { + // Only parse JSON to extract chat_id from the first chunk; + // skip on all subsequent chunks to avoid per-token overhead. + // Single lock acquisition covers both the check and the write. + { + let mut guard = + chat_id_clone.lock().unwrap_or_else(|e| e.into_inner()); + if guard.is_none() { + if let Ok(chunk_str) = std::str::from_utf8(&event.raw_bytes) + { + if let Some(data) = chunk_str.strip_prefix("data: ") { + if let Ok(serde_json::Value::Object(obj)) = + serde_json::from_str::( + data.trim(), + ) { - // Capture chat_id for use in the chain combinator - // The real hash will be registered there after accumulating all bytes - let mut cid = chat_id_inner.lock().await; - if cid.is_none() { - *cid = Some(id.clone()); + if let Some(serde_json::Value::String(id)) = + obj.get("id") + { + *guard = Some(id.clone()); } } } } } - - // raw_bytes contains "data: {...}\n", extract just the JSON part - let raw_str = String::from_utf8_lossy(&event.raw_bytes); - let json_data = raw_str - .trim() - .strip_prefix("data: ") - .unwrap_or(raw_str.trim()) - .to_string(); - tracing::debug!("Completion stream event: {}", json_data); - // Format as SSE event with proper newlines - let sse_bytes = Bytes::from(format!("data: {json_data}\n\n")); - accumulated_inner.lock().await.extend_from_slice(&sse_bytes); - Ok::(sse_bytes) } - Err(e) => { - let count = error_count_inner - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - if count == 0 { - tracing::error!( - model = %model_for_err, - error_type = %completion_stream_error_category(&e), - "Completion stream error" - ); - } - Ok::(Bytes::from(format!( - "data: error: {e}\n\n" - ))) + + // raw_bytes is "data: {...}\n"; append one "\n" for SSE double-newline. + // This avoids re-parsing + re-formatting the entire payload per token. + let mut buf = Vec::with_capacity(event.raw_bytes.len() + 1); + buf.extend_from_slice(&event.raw_bytes); + buf.push(b'\n'); + let sse_bytes = Bytes::from(buf); + accumulated_clone + .lock() + .unwrap_or_else(|e| e.into_inner()) + .extend_from_slice(&sse_bytes); + Ok::(sse_bytes) + } + Err(e) => { + let count = error_count_clone + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if count == 0 { + tracing::error!( + model = %request_model, + error_type = %completion_stream_error_category(&e), + "Completion stream error" + ); } + Ok::(Bytes::from(format!( + "data: error: {e}\n\n" + ))) } } }) @@ -474,7 +470,7 @@ pub async fn chat_completions( let done_bytes = Bytes::from_static(b"data: [DONE]\n\n"); accumulated_bytes .lock() - .await + .unwrap_or_else(|e| e.into_inner()) .extend_from_slice(&done_bytes); Ok::(done_bytes) diff --git a/crates/api/src/routes/responses.rs b/crates/api/src/routes/responses.rs index 9435435d0..9de5b2605 100644 --- a/crates/api/src/routes/responses.rs +++ b/crates/api/src/routes/responses.rs @@ -268,9 +268,10 @@ pub async fn create_response( "Successfully created streaming response, returning SSE stream with signature accumulation" ); - // Shared state for accumulating bytes and tracking response_id - let accumulated_bytes = Arc::new(tokio::sync::Mutex::new(Vec::new())); - let response_id_state = Arc::new(tokio::sync::Mutex::new(None::)); + // Shared state for accumulating bytes and tracking response_id. + // Use std::sync::Mutex (not tokio) — no await points inside critical sections. + let accumulated_bytes = Arc::new(std::sync::Mutex::new(Vec::::new())); + let response_id_state = Arc::new(std::sync::Mutex::new(None::)); let request_hash = body_hash.hash.clone(); // Clone for closures @@ -288,31 +289,54 @@ pub async fn create_response( // Extract response_id from response.created event if event.event_type == "response.created" { if let Some(ref response) = event.response { - let mut rid = response_id_inner.lock().await; + let mut rid = response_id_inner + .lock() + .unwrap_or_else(|e| e.into_inner()); if rid.is_none() { *rid = Some(response.id.clone()); - tracing::debug!("Extracted response_id: {}", response.id); + tracing::debug!( + "Extracted response_id: {}", + response.id + ); } } } // Format as SSE: "event: {type}\ndata: {json}\n\n" + // Pre-allocate buffer to avoid intermediate String allocations. let json = serde_json::to_string(&event) .expect("event serialization failed"); - let sse_bytes = format!("event: {}\ndata: {}\n\n", event.event_type, json); - let bytes = Bytes::from(sse_bytes); - - // Accumulate bytes synchronously - this ensures all bytes are captured - // before the stream chunk is yielded to the client - accumulated_inner.lock().await.extend_from_slice(&bytes); + let mut buf = Vec::with_capacity( + "event: ".len() + + event.event_type.len() + + "\ndata: ".len() + + json.len() + + "\n\n".len(), + ); + buf.extend_from_slice(b"event: "); + buf.extend_from_slice(event.event_type.as_bytes()); + buf.extend_from_slice(b"\ndata: "); + buf.extend_from_slice(json.as_bytes()); + buf.extend_from_slice(b"\n\n"); + let bytes = Bytes::from(buf); + + // Accumulate bytes — no await points, std::sync::Mutex is sufficient. + accumulated_inner + .lock() + .unwrap_or_else(|e| e.into_inner()) + .extend_from_slice(&bytes); // Check if stream is completing - store signature if event.event_type == "response.completed" { - // At this point, all bytes have been accumulated synchronously - // Now we can safely compute the hash and store the signature - let bytes_accumulated = accumulated_inner.lock().await.clone(); - let response_hash = compute_sha256(&bytes_accumulated); - if let Some(rid) = response_id_inner.lock().await.as_ref() { + // Hash directly from the guard reference — no Vec clone needed. + let guard = accumulated_inner + .lock() + .unwrap_or_else(|e| e.into_inner()); + let response_hash = compute_sha256(&guard); + let rid_guard = response_id_inner + .lock() + .unwrap_or_else(|e| e.into_inner()); + if let Some(rid) = rid_guard.as_ref() { let rid = rid.clone(); let req_hash = request_hash_inner.clone(); let attest = attestation_inner.clone(); @@ -325,14 +349,23 @@ pub async fn create_response( // but we've already computed the hash with complete data tokio::spawn(async move { // Store both ECDSA and ED25519 signatures - if let Err(e) = attest.store_response_signature( - &rid, - req_hash.clone(), - response_hash.clone(), - ).await { - tracing::error!("Failed to store response signature: {}", e); + if let Err(e) = attest + .store_response_signature( + &rid, + req_hash.clone(), + response_hash.clone(), + ) + .await + { + tracing::error!( + "Failed to store response signature: {}", + e + ); } else { - tracing::debug!("Successfully stored signature for response_id: {}", rid); + tracing::debug!( + "Successfully stored signature for response_id: {}", + rid + ); } }); } diff --git a/crates/inference_providers/src/vllm/mod.rs b/crates/inference_providers/src/vllm/mod.rs index d5e7d64d4..8882e3b78 100644 --- a/crates/inference_providers/src/vllm/mod.rs +++ b/crates/inference_providers/src/vllm/mod.rs @@ -41,15 +41,22 @@ mod encryption_headers { pub struct VLlmConfig { pub base_url: String, pub api_key: Option, - pub timeout_seconds: i64, + /// Timeout for non-streaming requests (e.g., model listing, attestation). + /// Streaming requests use only the connect timeout — no overall deadline — + /// so long generations are never killed mid-stream. + pub request_timeout_seconds: i64, } impl VLlmConfig { - pub fn new(base_url: String, api_key: Option, timeout_seconds: Option) -> Self { + pub fn new( + base_url: String, + api_key: Option, + request_timeout_seconds: Option, + ) -> Self { Self { base_url, api_key, - timeout_seconds: timeout_seconds.unwrap_or(30), + request_timeout_seconds: request_timeout_seconds.unwrap_or(30), } } } @@ -148,7 +155,9 @@ impl InferenceProvider for VLlmProvider { .client .get(&url) .headers(headers) - .timeout(Duration::from_secs(self.config.timeout_seconds as u64)) + .timeout(Duration::from_secs( + self.config.request_timeout_seconds as u64, + )) .send() .await .map_err(|e| CompletionError::CompletionError(e.to_string()))?; @@ -196,7 +205,9 @@ impl InferenceProvider for VLlmProvider { .client .get(&url) .headers(headers) - .timeout(Duration::from_secs(self.config.timeout_seconds as u64)) + .timeout(Duration::from_secs( + self.config.request_timeout_seconds as u64, + )) .send() .await .map_err(|e| AttestationError::FetchError(e.to_string()))?; @@ -236,7 +247,9 @@ impl InferenceProvider for VLlmProvider { .client .get(&url) .headers(headers) - .timeout(Duration::from_secs(self.config.timeout_seconds as u64)) + .timeout(Duration::from_secs( + self.config.request_timeout_seconds as u64, + )) .send() .await .map_err(|e| ListModelsError::FetchError(format!("{e:?}")))?; @@ -283,12 +296,13 @@ impl InferenceProvider for VLlmProvider { // Prepare encryption headers self.prepare_encryption_headers(&mut headers, &mut streaming_params.extra); + // No per-request timeout for streaming — the client-level connect_timeout + // protects against connection failures, and we must not kill long generations. let response = self .client .post(&url) .headers(headers) .json(&streaming_params) - .timeout(Duration::from_secs(self.config.timeout_seconds as u64)) .send() .await .map_err(|e| CompletionError::CompletionError(e.to_string()))?; @@ -337,7 +351,9 @@ impl InferenceProvider for VLlmProvider { .post(&url) .headers(headers) .json(&non_streaming_params) - .timeout(Duration::from_secs(self.config.timeout_seconds as u64)) + .timeout(Duration::from_secs( + self.config.request_timeout_seconds as u64, + )) .send() .await .map_err(|e| CompletionError::CompletionError(e.to_string()))?; @@ -393,12 +409,13 @@ impl InferenceProvider for VLlmProvider { let headers = self .build_headers() .map_err(CompletionError::CompletionError)?; + // No per-request timeout for streaming — the client-level connect_timeout + // protects against connection failures, and we must not kill long generations. let response = self .client .post(&url) .headers(headers) .json(&streaming_params) - .timeout(Duration::from_secs(self.config.timeout_seconds as u64)) .send() .await .map_err(|e| CompletionError::CompletionError(e.to_string()))?; @@ -531,7 +548,7 @@ impl InferenceProvider for VLlmProvider { .headers(headers) .multipart(form) .timeout(std::time::Duration::from_secs( - self.config.timeout_seconds as u64, + self.config.request_timeout_seconds as u64, )) .send() .await @@ -674,7 +691,7 @@ impl InferenceProvider for VLlmProvider { .headers(headers) .json(¶ms) .timeout(std::time::Duration::from_secs( - self.config.timeout_seconds as u64, + self.config.request_timeout_seconds as u64, )) .send() .await @@ -707,7 +724,7 @@ impl InferenceProvider for VLlmProvider { .headers(headers) .json(¶ms) .timeout(std::time::Duration::from_secs( - self.config.timeout_seconds as u64, + self.config.request_timeout_seconds as u64, )) .send() .await @@ -738,7 +755,7 @@ mod tests { VLlmProvider::new(VLlmConfig { base_url: "http://localhost".to_string(), api_key: None, - timeout_seconds: 30, + request_timeout_seconds: 30, }) } diff --git a/crates/inference_providers/tests/integration_tests.rs b/crates/inference_providers/tests/integration_tests.rs index 8eb2ced71..8454811b8 100644 --- a/crates/inference_providers/tests/integration_tests.rs +++ b/crates/inference_providers/tests/integration_tests.rs @@ -26,7 +26,7 @@ fn create_test_provider() -> Box { base_url: std::env::var("VLLM_BASE_URL") .unwrap_or_else(|_| "http://localhost:8002".to_string()), api_key: std::env::var("VLLM_API_KEY").ok(), - timeout_seconds: std::env::var("VLLM_TEST_TIMEOUT_SECS") + request_timeout_seconds: std::env::var("VLLM_TEST_TIMEOUT_SECS") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(30) as i64, @@ -691,7 +691,7 @@ async fn test_image_generation_real() { let config = VLlmConfig { base_url, api_key: std::env::var("VLLM_API_KEY").ok(), - timeout_seconds: 120, // Image generation can take longer + request_timeout_seconds: 120, // Image generation can take longer }; let provider = VLlmProvider::new(config); diff --git a/crates/services/Cargo.toml b/crates/services/Cargo.toml index 4c72510ee..794100c06 100644 --- a/crates/services/Cargo.toml +++ b/crates/services/Cargo.toml @@ -68,3 +68,20 @@ tokio-test = "0.4" async-trait = "0.1" futures = "0.3" mockall = "0.14" +criterion = { version = "0.5", features = ["async_tokio"] } + +[[bench]] +name = "completions_bench" +harness = false + +[[bench]] +name = "auth_bench" +harness = false + +[[bench]] +name = "provider_pool_bench" +harness = false + +[[bench]] +name = "responses_bench" +harness = false diff --git a/crates/services/benches/auth_bench.rs b/crates/services/benches/auth_bench.rs new file mode 100644 index 000000000..eafd35b14 --- /dev/null +++ b/crates/services/benches/auth_bench.rs @@ -0,0 +1,156 @@ +//! Benchmarks for the authentication hot path. +//! +//! Covers API key format validation, SHA-256 hashing, Moka cache hit/miss, +//! bloom filter positive/negative checks, and the combined fast-path simulation. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use services::common::{hash_api_key, is_valid_api_key_format}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Generate a realistic API key string (sk- prefix + 32 hex chars = 35 total). +fn make_api_key(seed: u64) -> String { + format!("sk-{:032x}", seed) +} + +// --------------------------------------------------------------------------- +// Benchmark group: api_key_validation +// --------------------------------------------------------------------------- + +fn bench_api_key_validation(c: &mut Criterion) { + let mut group = c.benchmark_group("api_key_validation"); + + let valid_key = make_api_key(42); + let invalid_key = "bad-key"; + + group.bench_function("format_validation_valid", |b| { + b.iter(|| black_box(is_valid_api_key_format(black_box(&valid_key)))) + }); + + group.bench_function("format_validation_invalid", |b| { + b.iter(|| black_box(is_valid_api_key_format(black_box(invalid_key)))) + }); + + group.bench_function("sha256_hash", |b| { + b.iter(|| black_box(hash_api_key(black_box(&valid_key)))) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: moka_cache +// --------------------------------------------------------------------------- + +fn bench_moka_cache(c: &mut Criterion) { + let mut group = c.benchmark_group("api_key_cache"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + // Build a cache matching production parameters (10K cap, 30s TTL). + let cache: moka::future::Cache = moka::future::Cache::builder() + .max_capacity(10_000) + .time_to_live(std::time::Duration::from_secs(30)) + .build(); + + let key = make_api_key(1); + let hashed = hash_api_key(&key); + + // Pre-populate for hit benchmark. + rt.block_on(cache.insert(hashed.clone(), "dummy-user-id".to_string())); + + group.bench_function("cache_hit", |b| { + b.iter(|| rt.block_on(async { black_box(cache.get(black_box(&hashed)).await) })) + }); + + let missing_key = hash_api_key(&make_api_key(999_999)); + + group.bench_function("cache_miss", |b| { + b.iter(|| rt.block_on(async { black_box(cache.get(black_box(&missing_key)).await) })) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: bloom_filter +// --------------------------------------------------------------------------- + +fn bench_bloom_filter(c: &mut Criterion) { + let mut group = c.benchmark_group("bloom_filter"); + + // Populate bloom filter with 1000 hashed keys. + let mut bloom = bloomfilter::Bloom::new_for_fp_rate(1000, 0.01).unwrap(); + let mut known_hash = String::new(); + for i in 0..1000u64 { + let h = hash_api_key(&make_api_key(i)); + if i == 500 { + known_hash = h.clone(); + } + bloom.set(&h); + } + + let absent_hash = hash_api_key(&make_api_key(1_000_000)); + + group.bench_function("check_positive", |b| { + b.iter(|| black_box(bloom.check(black_box(&known_hash)))) + }); + + group.bench_function("check_negative", |b| { + b.iter(|| black_box(bloom.check(black_box(&absent_hash)))) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: full_auth_hot_path +// --------------------------------------------------------------------------- + +fn bench_full_auth_hot_path(c: &mut Criterion) { + let mut group = c.benchmark_group("full_auth_hot_path"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let cache: moka::future::Cache = moka::future::Cache::builder() + .max_capacity(10_000) + .time_to_live(std::time::Duration::from_secs(30)) + .build(); + + let key = make_api_key(42); + let hashed = hash_api_key(&key); + rt.block_on(cache.insert(hashed.clone(), "user-id-abc".to_string())); + + // Simulate: format check → SHA-256 → cache hit. + group.bench_function("format_hash_cache_hit", |b| { + b.iter(|| { + let k = black_box(&key); + assert!(is_valid_api_key_format(k)); + let h = hash_api_key(k); + rt.block_on(async { black_box(cache.get(&h).await) }) + }) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +criterion_group!( + benches, + bench_api_key_validation, + bench_moka_cache, + bench_bloom_filter, + bench_full_auth_hot_path, +); +criterion_main!(benches); diff --git a/crates/services/benches/completions_bench.rs b/crates/services/benches/completions_bench.rs new file mode 100644 index 000000000..7d92e252c --- /dev/null +++ b/crates/services/benches/completions_bench.rs @@ -0,0 +1,785 @@ +//! Criterion microbenchmarks for completions hot-path optimizations. +//! +//! Benchmark groups: +//! 1. **sse_token_processing** — old async vs new sync per-token path +//! 2. **sse_operation_breakdown** — isolate each per-token operation +//! 3. **intercept_stream** — `InterceptStream::poll_next` throughput +//! 4. **intercept_stream_breakdown** — isolate poll_next sub-operations +//! 5. **model_resolution_cache** — moka cache hit vs miss + +use bytes::Bytes; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use futures::stream::StreamExt; +use std::sync::Arc; +use std::time::Instant; + +// --------------------------------------------------------------------------- +// Helpers: build synthetic SSE events +// --------------------------------------------------------------------------- + +/// Build a realistic SSE `data: {...}\n` payload for a chat completion chunk. +fn make_sse_payload(index: usize, is_last: bool) -> Bytes { + let usage = if is_last { + r#","usage":{"prompt_tokens":50,"completion_tokens":200,"total_tokens":250}"# + } else { + "" + }; + let finish_reason = if is_last { r#""stop""# } else { "null" }; + let content = if is_last { + String::new() + } else { + format!("token_{index}") + }; + + let json = format!( + r#"data: {{"id":"chatcmpl-bench000","object":"chat.completion.chunk","created":1700000000,"model":"bench-model","choices":[{{"index":0,"delta":{{"content":"{content}"}},"finish_reason":{finish_reason}}}]{usage}}}"#, + ); + Bytes::from(format!("{json}\n")) +} + +/// Build a Vec of raw SSE byte payloads simulating a 200-token stream. +fn make_sse_payloads(n: usize) -> Vec { + (0..n).map(|i| make_sse_payload(i, i == n - 1)).collect() +} + +/// Build SSEEvent objects for InterceptStream benchmarks. +fn make_sse_events( + n: usize, +) -> Vec> { + (0..n) + .map(|i| { + let is_last = i == n - 1; + let raw_bytes = make_sse_payload(i, is_last); + + let content = if is_last { + String::new() + } else { + format!("token_{i}") + }; + let finish_reason = if is_last { + Some(inference_providers::FinishReason::Stop) + } else { + None + }; + let usage = if is_last { + Some(inference_providers::TokenUsage::new(50, 200)) + } else { + None + }; + + let chunk = inference_providers::models::ChatCompletionChunk { + id: "chatcmpl-bench000".to_string(), + object: "chat.completion.chunk".to_string(), + created: 1_700_000_000, + model: "bench-model".to_string(), + system_fingerprint: None, + choices: vec![inference_providers::models::ChatChoice { + index: 0, + delta: Some(inference_providers::ChatDelta { + role: None, + content: Some(content), + name: None, + tool_call_id: None, + tool_calls: None, + reasoning_content: None, + reasoning: None, + }), + logprobs: None, + finish_reason, + token_ids: None, + }], + usage, + prompt_token_ids: None, + modality: None, + }; + + Ok(inference_providers::SSEEvent { + raw_bytes, + chunk: inference_providers::StreamChunk::Chat(chunk), + }) + }) + .collect() +} + +// =========================================================================== +// Group 1: SSE token processing — old (async) vs new (sync) per-token path +// =========================================================================== + +/// **Old path** (pre-optimisation): +/// - `tokio::Mutex` for accumulated_bytes and chat_id_state +/// - `String::from_utf8` (allocating) on every token +/// - Parse JSON on every token to extract chat_id +/// - Uses `.then()` (async) on the stream +async fn process_stream_old(payloads: Vec) { + let accumulated_bytes = Arc::new(tokio::sync::Mutex::new(Vec::::new())); + let chat_id_state = Arc::new(tokio::sync::Mutex::new(None::)); + + let stream = futures::stream::iter(payloads.into_iter().map(Ok::<_, std::convert::Infallible>)); + + let _: Vec<_> = stream + .then(move |result| { + let accumulated = accumulated_bytes.clone(); + let chat_id = chat_id_state.clone(); + async move { + let event_bytes = result.unwrap(); + + // Parse JSON on every token (old behaviour) + if let Ok(chunk_str) = String::from_utf8(event_bytes.to_vec()) { + if let Some(data) = chunk_str.strip_prefix("data: ") { + if let Ok(serde_json::Value::Object(obj)) = + serde_json::from_str::(data.trim()) + { + if let Some(serde_json::Value::String(id)) = obj.get("id") { + let mut guard = chat_id.lock().await; + *guard = Some(id.clone()); + } + } + } + } + + // Accumulate bytes + let raw_str = String::from_utf8_lossy(&event_bytes); + let json_data = raw_str + .trim() + .strip_prefix("data: ") + .unwrap_or(raw_str.trim()); + let sse_bytes = Bytes::from(format!("data: {json_data}\n\n")); + accumulated.lock().await.extend_from_slice(&sse_bytes); + sse_bytes + } + }) + .collect() + .await; +} + +/// **New path** (optimised): +/// - `std::sync::Mutex` for accumulated_bytes and chat_id_state +/// - `std::str::from_utf8` (zero-copy) for chat_id extraction +/// - Parse JSON only on first token (skip when chat_id already set) +/// - Uses `.map()` (sync) on the stream +/// +/// The caller provides the runtime to avoid measuring runtime-construction +/// overhead inside `b.iter()`. +fn process_stream_new(payloads: Vec, rt: &tokio::runtime::Runtime) { + let accumulated_bytes = Arc::new(std::sync::Mutex::new(Vec::::new())); + let chat_id_state = Arc::new(std::sync::Mutex::new(None::)); + + let stream = futures::stream::iter(payloads.into_iter().map(Ok::<_, std::convert::Infallible>)); + + let accumulated_clone = accumulated_bytes.clone(); + let chat_id_clone = chat_id_state.clone(); + + let mapped = stream.map(move |result| { + let event_bytes = result.unwrap(); + + // Only parse JSON for chat_id on first token + { + let mut guard = chat_id_clone.lock().unwrap_or_else(|e| e.into_inner()); + if guard.is_none() { + if let Ok(chunk_str) = std::str::from_utf8(&event_bytes) { + if let Some(data) = chunk_str.strip_prefix("data: ") { + if let Ok(serde_json::Value::Object(obj)) = + serde_json::from_str::(data.trim()) + { + if let Some(serde_json::Value::String(id)) = obj.get("id") { + *guard = Some(id.clone()); + } + } + } + } + } + } + + // Accumulate bytes — append "\n" to raw_bytes ("data: {...}\n" → "data: {...}\n\n") + let mut buf = Vec::with_capacity(event_bytes.len() + 1); + buf.extend_from_slice(&event_bytes); + buf.push(b'\n'); + let sse_bytes = Bytes::from(buf); + accumulated_clone + .lock() + .unwrap_or_else(|e| e.into_inner()) + .extend_from_slice(&sse_bytes); + sse_bytes + }); + + rt.block_on(async { + let _: Vec<_> = mapped.collect().await; + }); +} + +fn bench_sse_token_processing(c: &mut Criterion) { + let mut group = c.benchmark_group("sse_token_processing"); + let token_count: usize = 200; + group.throughput(Throughput::Elements(token_count as u64)); + + let payloads = make_sse_payloads(token_count); + + group.bench_with_input( + BenchmarkId::new("old_async_path", token_count), + &payloads, + |b, payloads| { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + b.iter(|| { + rt.block_on(process_stream_old(payloads.clone())); + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new("new_sync_path", token_count), + &payloads, + |b, payloads| { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + b.iter(|| { + process_stream_new(payloads.clone(), &rt); + }); + }, + ); + + group.finish(); +} + +// =========================================================================== +// Group 2: SSE operation breakdown — isolate per-token costs +// =========================================================================== + +fn bench_sse_operation_breakdown(c: &mut Criterion) { + let mut group = c.benchmark_group("sse_operation_breakdown"); + let token_count: usize = 200; + group.throughput(Throughput::Elements(token_count as u64)); + + let payloads = make_sse_payloads(token_count); + + // 2a: Just String::from_utf8_lossy + strip + format (the re-framing cost) + group.bench_with_input( + BenchmarkId::new("reframe_bytes_only", token_count), + &payloads, + |b, payloads| { + b.iter(|| { + for payload in payloads { + let raw_str = String::from_utf8_lossy(payload); + let json_data = raw_str + .trim() + .strip_prefix("data: ") + .unwrap_or(raw_str.trim()); + let sse_bytes = Bytes::from(format!("data: {json_data}\n\n")); + black_box(sse_bytes); + } + }); + }, + ); + + // 2b: Just the mutex lock + extend (accumulation cost) + group.bench_with_input( + BenchmarkId::new("mutex_lock_and_accumulate", token_count), + &payloads, + |b, payloads| { + b.iter(|| { + let accumulated = std::sync::Mutex::new(Vec::::new()); + for payload in payloads { + accumulated.lock().unwrap().extend_from_slice(payload); + } + black_box(accumulated); + }); + }, + ); + + // 2c: serde_json parse of a single chunk (the JSON parse cost) + let single_payload = &payloads[0]; + group.bench_function("json_parse_single_chunk", |b| { + b.iter(|| { + let chunk_str = std::str::from_utf8(single_payload).unwrap(); + let data = chunk_str.strip_prefix("data: ").unwrap(); + let val: serde_json::Value = black_box(serde_json::from_str(data.trim()).unwrap()); + black_box(val); + }); + }); + + // 2d: json_parse * 200 (old path) vs json_parse * 1 (new path) + group.bench_with_input( + BenchmarkId::new("json_parse_all_tokens", token_count), + &payloads, + |b, payloads| { + b.iter(|| { + for payload in payloads { + if let Ok(chunk_str) = std::str::from_utf8(payload) { + if let Some(data) = chunk_str.strip_prefix("data: ") { + let _: Result = + black_box(serde_json::from_str(data.trim())); + } + } + } + }); + }, + ); + + // 2e: Bytes::from(format!(...)) allocation per token + group.bench_with_input( + BenchmarkId::new("bytes_format_alloc", token_count), + &payloads, + |b, payloads| { + b.iter(|| { + for payload in payloads { + let raw_str = std::str::from_utf8(payload).unwrap(); + let json_data = raw_str + .trim() + .strip_prefix("data: ") + .unwrap_or(raw_str.trim()); + let sse_bytes = Bytes::from(format!("data: {json_data}\n\n")); + black_box(sse_bytes); + } + }); + }, + ); + + // 2f: Zero-copy passthrough (raw_bytes already have correct format) + group.bench_with_input( + BenchmarkId::new("zero_copy_passthrough", token_count), + &payloads, + |b, payloads| { + b.iter(|| { + for payload in payloads { + // If raw_bytes were already correctly formatted, we could just pass through + black_box(payload.clone()); + } + }); + }, + ); + + group.finish(); +} + +// =========================================================================== +// Group 3: InterceptStream poll_next throughput +// =========================================================================== + +struct NoOpMetrics; + +#[async_trait::async_trait] +impl services::metrics::MetricsServiceTrait for NoOpMetrics { + fn record_latency(&self, _name: &str, _duration: std::time::Duration, _tags: &[&str]) {} + fn record_count(&self, _name: &str, _value: i64, _tags: &[&str]) {} + fn record_histogram(&self, _name: &str, _value: f64, _tags: &[&str]) {} +} + +struct NoOpAttestation; + +#[async_trait::async_trait] +impl services::attestation::ports::AttestationServiceTrait for NoOpAttestation { + async fn get_chat_signature( + &self, + _chat_id: &str, + _signing_algo: Option, + ) -> Result< + services::attestation::models::SignatureLookupResult, + services::attestation::models::AttestationError, + > { + unimplemented!("not used in benchmark") + } + async fn store_chat_signature_from_provider( + &self, + _chat_id: &str, + ) -> Result<(), services::attestation::models::AttestationError> { + Ok(()) + } + async fn store_response_signature( + &self, + _response_id: &str, + _request_hash: String, + _response_hash: String, + ) -> Result<(), services::attestation::models::AttestationError> { + Ok(()) + } + async fn get_attestation_report( + &self, + _model: Option, + _signing_algo: Option, + _nonce: Option, + _signing_address: Option, + ) -> Result< + services::attestation::models::AttestationReport, + services::attestation::models::AttestationError, + > { + unimplemented!("not used in benchmark") + } + async fn verify_vpc_signature( + &self, + _timestamp: i64, + _signature: String, + ) -> Result { + unimplemented!("not used in benchmark") + } +} + +struct NoOpUsage; + +#[async_trait::async_trait] +impl services::usage::UsageServiceTrait for NoOpUsage { + async fn calculate_cost( + &self, + _model_id: &str, + _input_tokens: i32, + _output_tokens: i32, + ) -> Result { + unimplemented!("not used in benchmark") + } + async fn record_usage( + &self, + _request: services::usage::RecordUsageServiceRequest, + ) -> Result { + Ok(dummy_usage_log_entry()) + } + async fn record_usage_from_api( + &self, + _organization_id: uuid::Uuid, + _workspace_id: uuid::Uuid, + _api_key_id: uuid::Uuid, + _request: services::usage::RecordUsageApiRequest, + ) -> Result { + unimplemented!("not used in benchmark") + } + async fn check_can_use( + &self, + _organization_id: uuid::Uuid, + ) -> Result { + unimplemented!("not used in benchmark") + } + async fn get_balance( + &self, + _organization_id: uuid::Uuid, + ) -> Result, services::usage::UsageError> { + unimplemented!("not used in benchmark") + } + async fn get_usage_history( + &self, + _organization_id: uuid::Uuid, + _limit: Option, + _offset: Option, + ) -> Result<(Vec, i64), services::usage::UsageError> { + unimplemented!("not used in benchmark") + } + async fn get_limit( + &self, + _organization_id: uuid::Uuid, + ) -> Result, services::usage::UsageError> { + unimplemented!("not used in benchmark") + } + async fn get_usage_history_by_api_key( + &self, + _api_key_id: uuid::Uuid, + _limit: Option, + _offset: Option, + ) -> Result<(Vec, i64), services::usage::UsageError> { + unimplemented!("not used in benchmark") + } + async fn get_api_key_usage_history_with_permissions( + &self, + _workspace_id: uuid::Uuid, + _api_key_id: uuid::Uuid, + _user_id: uuid::Uuid, + _limit: Option, + _offset: Option, + ) -> Result<(Vec, i64), services::usage::UsageError> { + unimplemented!("not used in benchmark") + } + async fn get_costs_by_inference_ids( + &self, + _organization_id: uuid::Uuid, + _inference_ids: Vec, + ) -> Result, services::usage::UsageError> { + unimplemented!("not used in benchmark") + } +} + +fn dummy_usage_log_entry() -> services::usage::UsageLogEntry { + services::usage::UsageLogEntry { + id: uuid::Uuid::nil(), + organization_id: uuid::Uuid::nil(), + workspace_id: uuid::Uuid::nil(), + api_key_id: uuid::Uuid::nil(), + model_id: uuid::Uuid::nil(), + model: "bench-model".to_string(), + input_tokens: 0, + output_tokens: 0, + total_tokens: 0, + input_cost: 0, + output_cost: 0, + total_cost: 0, + inference_type: services::usage::InferenceType::ChatCompletionStream, + created_at: chrono::Utc::now(), + ttft_ms: None, + avg_itl_ms: None, + inference_id: None, + provider_request_id: None, + stop_reason: None, + response_id: None, + image_count: None, + was_inserted: true, + } +} + +fn build_intercept_stream( + events: Vec>, +) -> services::completions::InterceptStream< + futures::stream::Iter< + std::vec::IntoIter< + Result, + >, + >, +> { + let now = Instant::now(); + services::completions::InterceptStream { + inner: futures::stream::iter(events), + attestation_service: Arc::new(NoOpAttestation), + usage_service: Arc::new(NoOpUsage), + metrics_service: Arc::new(NoOpMetrics), + organization_id: uuid::Uuid::nil(), + workspace_id: uuid::Uuid::nil(), + api_key_id: uuid::Uuid::nil(), + model_id: uuid::Uuid::nil(), + model_name: "bench-model".to_string(), + inference_type: services::usage::InferenceType::ChatCompletionStream, + service_start_time: now, + provider_start_time: now, + first_token_received: false, + first_token_time: None, + ttft_ms: None, + token_count: 0, + last_token_time: None, + total_itl_ms: 0.0, + metric_tags: vec![ + "model:bench-model".to_string(), + "environment:bench".to_string(), + ], + concurrent_counter: None, + last_usage_stats: None, + last_chat_id: None, + stream_completed: false, + response_id: None, + last_finish_reason: None, + last_error: None, + state: services::completions::StreamState::Streaming, + attestation_supported: false, + } +} + +fn bench_intercept_stream(c: &mut Criterion) { + let mut group = c.benchmark_group("intercept_stream"); + let token_count: usize = 200; + group.throughput(Throughput::Elements(token_count as u64)); + + let events = make_sse_events(token_count); + + group.bench_with_input( + BenchmarkId::new("poll_next_200_tokens", token_count), + &events, + |b, events| { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + b.iter(|| { + let stream = build_intercept_stream(events.clone()); + rt.block_on(async { + let _: Vec<_> = stream.collect().await; + }); + }); + }, + ); + + group.finish(); +} + +// =========================================================================== +// Group 4: InterceptStream operation breakdown +// =========================================================================== + +fn bench_intercept_stream_breakdown(c: &mut Criterion) { + let mut group = c.benchmark_group("intercept_stream_breakdown"); + let token_count: usize = 200; + group.throughput(Throughput::Elements(token_count as u64)); + + let events = make_sse_events(token_count); + + // 4a: Cost of cloning SSEEvents (done inside poll_next via event.clone()) + group.bench_with_input( + BenchmarkId::new("sse_event_clone", token_count), + &events, + |b, events| { + b.iter(|| { + for e in events.iter().flatten() { + black_box(e.clone()); + } + }); + }, + ); + + // 4b: Cost of Instant::now() per token (called on every poll_next) + group.bench_with_input( + BenchmarkId::new("instant_now_per_token", token_count), + &token_count, + |b, &n| { + b.iter(|| { + for _ in 0..n { + black_box(Instant::now()); + } + }); + }, + ); + + // 4c: Cost of chat_chunk.id.clone() per token + group.bench_with_input( + BenchmarkId::new("string_clone_chat_id", token_count), + &events, + |b, events| { + b.iter(|| { + for e in events.iter().flatten() { + if let inference_providers::StreamChunk::Chat(ref chunk) = e.chunk { + black_box(chunk.id.clone()); + } + } + }); + }, + ); + + // 4d: Cost of building the Vec<&str> metric tags per first-token call + group.bench_function("metric_tags_vec_build", |b| { + let metric_tags = [ + "model:bench-model".to_string(), + "environment:bench".to_string(), + ]; + b.iter(|| { + let tags_str: Vec<&str> = metric_tags.iter().map(|s| s.as_str()).collect(); + black_box(tags_str); + }); + }); + + // 4e: Cost of stream setup (build_intercept_stream without polling) + group.bench_with_input( + BenchmarkId::new("stream_construction", token_count), + &events, + |b, events| { + b.iter(|| { + let stream = build_intercept_stream(events.clone()); + black_box(stream); + }); + }, + ); + + // 4f: Cost of cloning the events Vec (benchmark overhead baseline) + group.bench_with_input( + BenchmarkId::new("events_vec_clone", token_count), + &events, + |b, events| { + b.iter(|| { + black_box(events.clone()); + }); + }, + ); + + group.finish(); +} + +// =========================================================================== +// Group 5: Model resolution cache — hit vs miss +// =========================================================================== + +fn make_test_model() -> services::models::ModelWithPricing { + services::models::ModelWithPricing { + id: uuid::Uuid::nil(), + model_name: "bench/test-model".to_string(), + model_display_name: "Bench Test Model".to_string(), + model_description: "A model for benchmarking".to_string(), + model_icon: None, + input_cost_per_token: 100, + output_cost_per_token: 300, + cost_per_image: 0, + context_length: 8192, + verifiable: false, + aliases: vec!["bench-alias".to_string()], + owned_by: "benchmark".to_string(), + provider_type: "vllm".to_string(), + provider_config: None, + attestation_supported: false, + input_modalities: Some(vec!["text".to_string()]), + output_modalities: Some(vec!["text".to_string()]), + } +} + +fn bench_model_resolution_cache(c: &mut Criterion) { + let mut group = c.benchmark_group("model_resolution_cache"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + // Build a moka cache matching production configuration (60s TTL, 1000 capacity) + let cache: moka::future::Cache> = + moka::future::Cache::builder() + .max_capacity(1_000) + .time_to_live(std::time::Duration::from_secs(60)) + .build(); + + // Pre-populate for cache-hit benchmark + let model = make_test_model(); + rt.block_on(cache.insert("bench/test-model".to_string(), Some(model.clone()))); + + group.bench_function("cache_hit", |b| { + b.iter(|| { + rt.block_on(async { + let _ = cache.get("bench/test-model").await; + }); + }); + }); + + group.bench_function("cache_miss_and_insert", |b| { + let model_for_insert = make_test_model(); + // Use iter_batched to create a fresh cache per iteration, guaranteeing a true miss. + b.iter_batched( + || { + moka::future::Cache::builder() + .max_capacity(1_000) + .time_to_live(std::time::Duration::from_secs(60)) + .build() + }, + |miss_cache: moka::future::Cache< + String, + Option, + >| { + rt.block_on(async { + let key = "bench/test-model"; + let cached = miss_cache.get(key).await; + if cached.is_none() { + miss_cache + .insert(key.to_string(), Some(model_for_insert.clone())) + .await; + } + }); + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.finish(); +} + +// =========================================================================== +// Criterion harness +// =========================================================================== + +criterion_group!( + benches, + bench_sse_token_processing, + bench_sse_operation_breakdown, + bench_intercept_stream, + bench_intercept_stream_breakdown, + bench_model_resolution_cache, +); +criterion_main!(benches); diff --git a/crates/services/benches/provider_pool_bench.rs b/crates/services/benches/provider_pool_bench.rs new file mode 100644 index 000000000..ce6e706a6 --- /dev/null +++ b/crates/services/benches/provider_pool_bench.rs @@ -0,0 +1,287 @@ +//! Benchmarks for the inference provider pool hot path. +//! +//! Covers round-robin index key formatting, mutex-guarded selection, provider +//! ordering with varying pool sizes, sticky routing cache, RwLock+HashMap model +//! lookup, and pub-key filtering via `Arc::ptr_eq`. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; + +use async_trait::async_trait; +use inference_providers::models::{ + AttestationError, AudioTranscriptionError, AudioTranscriptionParams, + AudioTranscriptionResponse, ChatCompletionParams, ChatCompletionResponseWithBytes, + ChatSignature, CompletionError, CompletionParams, ImageEditError, ImageEditParams, + ImageEditResponseWithBytes, ImageGenerationError, ImageGenerationParams, + ImageGenerationResponseWithBytes, ListModelsError, ModelsResponse, RerankError, RerankParams, + RerankResponse, ScoreError, ScoreParams, ScoreResponse, +}; +use inference_providers::{InferenceProvider, StreamingResult}; + +// --------------------------------------------------------------------------- +// Stub provider (all methods unimplemented — we never call them) +// --------------------------------------------------------------------------- + +struct StubProvider; + +#[async_trait] +impl InferenceProvider for StubProvider { + async fn models(&self) -> Result { + unimplemented!() + } + async fn chat_completion_stream( + &self, + _: ChatCompletionParams, + _: String, + ) -> Result { + unimplemented!() + } + async fn chat_completion( + &self, + _: ChatCompletionParams, + _: String, + ) -> Result { + unimplemented!() + } + async fn text_completion_stream( + &self, + _: CompletionParams, + ) -> Result { + unimplemented!() + } + async fn image_generation( + &self, + _: ImageGenerationParams, + _: String, + ) -> Result { + unimplemented!() + } + async fn image_edit( + &self, + _: Arc, + _: String, + ) -> Result { + unimplemented!() + } + async fn score(&self, _: ScoreParams, _: String) -> Result { + unimplemented!() + } + async fn rerank(&self, _: RerankParams) -> Result { + unimplemented!() + } + async fn get_signature( + &self, + _: &str, + _: Option, + ) -> Result { + unimplemented!() + } + async fn get_attestation_report( + &self, + _: String, + _: Option, + _: Option, + _: Option, + ) -> Result, AttestationError> { + unimplemented!() + } + async fn audio_transcription( + &self, + _: AudioTranscriptionParams, + _: String, + ) -> Result { + unimplemented!() + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +type DynProvider = Arc; + +fn make_providers(n: usize) -> Vec { + (0..n) + .map(|_| Arc::new(StubProvider) as DynProvider) + .collect() +} + +// --------------------------------------------------------------------------- +// Benchmark group: round_robin +// --------------------------------------------------------------------------- + +fn bench_round_robin(c: &mut Criterion) { + let mut group = c.benchmark_group("round_robin"); + + let model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"; + + group.bench_function("index_key_format", |b| { + b.iter(|| black_box(format!("id:{}", black_box(model_id)))) + }); + + // Simulate the mutex-guarded round-robin selection (matches production code). + let index: Arc>> = Arc::new(Mutex::new(HashMap::new())); + let providers = make_providers(3); + let key = format!("id:{}", model_id); + + group.bench_function("mutex_and_select_3", |b| { + b.iter(|| { + let mut guard = index.lock().unwrap(); + let entry = guard.entry(key.clone()).or_insert(0); + let selected = *entry % providers.len(); + *entry = (*entry + 1) % providers.len(); + black_box(selected); + }) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: provider_ordering +// --------------------------------------------------------------------------- + +fn bench_provider_ordering(c: &mut Criterion) { + let mut group = c.benchmark_group("provider_ordering"); + + let index: Arc>> = Arc::new(Mutex::new(HashMap::new())); + + for n in [3, 10] { + let providers = make_providers(n); + let key = format!("id:model-{}", n); + + group.bench_with_input(BenchmarkId::new("order_providers", n), &n, |b, _| { + b.iter(|| { + let mut guard = index.lock().unwrap(); + let entry = guard.entry(key.clone()).or_insert(0); + let start = *entry % providers.len(); + *entry = (*entry + 1) % providers.len(); + + // Build ordered vec rotating from start index (matches production). + let mut ordered = Vec::with_capacity(providers.len()); + for i in 0..providers.len() { + ordered.push(providers[(start + i) % providers.len()].clone()); + } + black_box(ordered); + }) + }); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: sticky_routing +// --------------------------------------------------------------------------- + +fn bench_sticky_routing(c: &mut Criterion) { + let mut group = c.benchmark_group("sticky_routing"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + // Production parameters: 100K cap, 1h TTL. + let cache: moka::future::Cache = moka::future::Cache::builder() + .max_capacity(100_000) + .time_to_live(Duration::from_secs(3600)) + .build(); + + let chat_id = "chatcmpl-abc123def456".to_string(); + let provider: DynProvider = Arc::new(StubProvider); + + rt.block_on(cache.insert(chat_id.clone(), provider)); + + group.bench_function("cache_hit", |b| { + b.iter(|| rt.block_on(async { black_box(cache.get(black_box(&chat_id)).await) })) + }); + + let missing_chat_id = "chatcmpl-missing-999".to_string(); + + group.bench_function("cache_miss", |b| { + b.iter(|| rt.block_on(async { black_box(cache.get(black_box(&missing_chat_id)).await) })) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: rwlock_model_lookup +// --------------------------------------------------------------------------- + +fn bench_rwlock_model_lookup(c: &mut Criterion) { + let mut group = c.benchmark_group("rwlock_model_lookup"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let mut model_map: HashMap> = HashMap::new(); + model_map.insert( + "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(), + make_providers(3), + ); + model_map.insert("meta-llama/Llama-3-70B".to_string(), make_providers(5)); + + let lock = Arc::new(tokio::sync::RwLock::new(model_map)); + let model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(); + + group.bench_function("read_and_lookup", |b| { + b.iter(|| { + rt.block_on(async { + let guard = lock.read().await; + black_box(guard.get(black_box(&model_id)).is_some()) + }) + }) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: pubkey_filtering +// --------------------------------------------------------------------------- + +fn bench_pubkey_filtering(c: &mut Criterion) { + let mut group = c.benchmark_group("pubkey_filtering"); + + // Simulate N model providers and M pubkey providers, intersect via Arc::ptr_eq. + let all_providers = make_providers(5); + + // Model has all 5 providers. + let model_providers = all_providers.clone(); + // Pubkey matches providers 1 and 3 (simulate 2 out of 5). + let pubkey_providers = [all_providers[1].clone(), all_providers[3].clone()]; + + group.bench_function("ptr_eq_5_providers", |b| { + b.iter(|| { + let intersection: Vec<_> = model_providers + .iter() + .filter(|mp| pubkey_providers.iter().any(|pp| Arc::ptr_eq(mp, pp))) + .cloned() + .collect(); + black_box(intersection); + }) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +criterion_group!( + benches, + bench_round_robin, + bench_provider_ordering, + bench_sticky_routing, + bench_rwlock_model_lookup, + bench_pubkey_filtering, +); +criterion_main!(benches); diff --git a/crates/services/benches/responses_bench.rs b/crates/services/benches/responses_bench.rs new file mode 100644 index 000000000..73f46a7cd --- /dev/null +++ b/crates/services/benches/responses_bench.rs @@ -0,0 +1,284 @@ +//! Benchmarks for the Responses API streaming hot path. +//! +//! Covers ResponseStreamEvent serialization (delta and created variants), +//! SSE formatting, tokio::sync::Mutex vs std::sync::Mutex accumulation, +//! full 200-event stream simulation, and SHA-256 hashing at various sizes. + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use sha2::{Digest, Sha256}; + +use services::responses::models::{ + ResponseContentItem, ResponseItemStatus, ResponseObject, ResponseOutputItem, ResponseStatus, + ResponseStreamEvent, ResponseToolChoiceOutput, Usage, +}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Build a minimal delta event (the most frequent event type during streaming). +fn make_delta_event(seq: u64) -> ResponseStreamEvent { + ResponseStreamEvent { + event_type: "response.output_text.delta".to_string(), + sequence_number: Some(seq), + response: None, + output_index: Some(0), + content_index: Some(0), + item: None, + item_id: None, + part: None, + delta: Some("token".to_string()), + text: None, + logprobs: None, + obfuscation: None, + annotation_index: None, + annotation: None, + conversation_title: None, + } +} + +/// Build a response.created event containing a full ResponseObject. +fn make_created_event() -> ResponseStreamEvent { + let response = ResponseObject { + id: "resp_bench000000000000000000000001".to_string(), + object: "response".to_string(), + created_at: 1700000000, + status: ResponseStatus::InProgress, + background: false, + conversation: None, + error: None, + incomplete_details: None, + instructions: None, + max_output_tokens: None, + max_tool_calls: None, + model: "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(), + output: vec![ResponseOutputItem::Message { + id: "msg_bench00000000000000000000001".to_string(), + response_id: "resp_bench000000000000000000000001".to_string(), + previous_response_id: None, + next_response_ids: vec![], + created_at: 1700000000, + status: ResponseItemStatus::InProgress, + role: "assistant".to_string(), + content: vec![ResponseContentItem::OutputText { + text: String::new(), + annotations: vec![], + logprobs: vec![], + }], + model: "Qwen/Qwen3-30B-A3B-Instruct-2507".to_string(), + metadata: None, + }], + parallel_tool_calls: false, + previous_response_id: None, + next_response_ids: vec![], + prompt_cache_key: None, + prompt_cache_retention: None, + reasoning: None, + safety_identifier: None, + service_tier: "default".to_string(), + store: false, + temperature: 1.0, + tool_choice: ResponseToolChoiceOutput::Auto("auto".to_string()), + tools: vec![], + top_logprobs: 0, + top_p: 1.0, + truncation: "disabled".to_string(), + usage: Usage::new(0, 0), + user: None, + metadata: None, + }; + + ResponseStreamEvent { + event_type: "response.created".to_string(), + sequence_number: Some(0), + response: Some(response), + output_index: None, + content_index: None, + item: None, + item_id: None, + part: None, + delta: None, + text: None, + logprobs: None, + obfuscation: None, + annotation_index: None, + annotation: None, + conversation_title: None, + } +} + +// --------------------------------------------------------------------------- +// Benchmark group: event_serialization +// --------------------------------------------------------------------------- + +fn bench_event_serialization(c: &mut Criterion) { + let mut group = c.benchmark_group("response_event_serialize"); + + let delta = make_delta_event(1); + let created = make_created_event(); + + group.bench_function("delta", |b| { + b.iter(|| black_box(serde_json::to_string(black_box(&delta)).unwrap())) + }); + + group.bench_function("created", |b| { + b.iter(|| black_box(serde_json::to_string(black_box(&created)).unwrap())) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: sse_formatting +// --------------------------------------------------------------------------- + +fn bench_sse_formatting(c: &mut Criterion) { + let mut group = c.benchmark_group("sse_format_response"); + + let delta = make_delta_event(1); + let json = serde_json::to_string(&delta).unwrap(); + let event_type = &delta.event_type; + + group.bench_function("format_sse_line", |b| { + b.iter(|| { + black_box(format!( + "event: {}\ndata: {}\n\n", + black_box(event_type), + black_box(&json) + )) + }) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: mutex_accumulation (tokio vs std) +// --------------------------------------------------------------------------- + +fn bench_mutex_accumulation(c: &mut Criterion) { + let mut group = c.benchmark_group("mutex_accumulation"); + + let chunk = b"token_chunk_data_here_"; + + // tokio::sync::Mutex + { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + let mutex = tokio::sync::Mutex::new(Vec::::with_capacity(4096)); + + group.bench_function("tokio_mutex_200_events", |b| { + b.iter(|| { + rt.block_on(async { + // Reset + mutex.lock().await.clear(); + for _ in 0..200 { + let mut guard = mutex.lock().await; + guard.extend_from_slice(black_box(chunk)); + } + black_box(mutex.lock().await.len()); + }) + }) + }); + } + + // std::sync::Mutex + { + let mutex = std::sync::Mutex::new(Vec::::with_capacity(4096)); + + group.bench_function("std_mutex_200_events", |b| { + b.iter(|| { + mutex.lock().unwrap().clear(); + for _ in 0..200 { + let mut guard = mutex.lock().unwrap(); + guard.extend_from_slice(black_box(chunk)); + } + black_box(mutex.lock().unwrap().len()); + }) + }); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: full_response_stream +// --------------------------------------------------------------------------- + +fn bench_full_response_stream(c: &mut Criterion) { + let mut group = c.benchmark_group("response_stream"); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + // Pre-build 200 delta events. + let events: Vec = (0..200).map(make_delta_event).collect(); + + let accumulated = tokio::sync::Mutex::new(Vec::::with_capacity(32 * 1024)); + + group.throughput(Throughput::Elements(200)); + + group.bench_function("200_delta_events", |b| { + b.iter(|| { + rt.block_on(async { + accumulated.lock().await.clear(); + for event in &events { + let json = serde_json::to_string(event).unwrap(); + let sse = format!("event: {}\ndata: {}\n\n", event.event_type, json); + let mut guard = accumulated.lock().await; + guard.extend_from_slice(sse.as_bytes()); + } + black_box(accumulated.lock().await.len()); + }) + }) + }); + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Benchmark group: sha256_accumulated +// --------------------------------------------------------------------------- + +fn bench_sha256_accumulated(c: &mut Criterion) { + let mut group = c.benchmark_group("sha256_accumulated"); + + for size_kb in [1, 10, 100] { + let data = vec![b'x'; size_kb * 1024]; + + group.throughput(Throughput::Bytes(data.len() as u64)); + + group.bench_with_input( + BenchmarkId::new("hash", format!("{}kb", size_kb)), + &data, + |b, data| { + b.iter(|| { + let mut hasher = Sha256::new(); + hasher.update(black_box(data)); + let hash = hasher.finalize(); + black_box(hex::encode(hash)) + }) + }, + ); + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +criterion_group!( + benches, + bench_event_serialization, + bench_sse_formatting, + bench_mutex_accumulation, + bench_full_response_stream, + bench_sha256_accumulated, +); +criterion_main!(benches); diff --git a/crates/services/src/completions/mod.rs b/crates/services/src/completions/mod.rs index 37467d4a2..63affaa01 100644 --- a/crates/services/src/completions/mod.rs +++ b/crates/services/src/completions/mod.rs @@ -13,18 +13,16 @@ use uuid::Uuid; // Create a new stream that intercepts messages, but passes the original ones through use crate::metrics::{consts::*, MetricsServiceTrait}; -use futures_util::{Future, Stream}; +use futures_util::Stream; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; const FINALIZE_TIMEOUT_SECS: u64 = 5; -type FinalizeFuture = Pin + Send>>; - -enum StreamState { +#[doc(hidden)] +pub enum StreamState { Streaming, - Finalizing(FinalizeFuture), Done, } @@ -48,79 +46,88 @@ fn get_input_bucket(token_count: i32) -> &'static str { } } -struct InterceptStream +#[doc(hidden)] +pub struct InterceptStream where S: Stream> + Unpin, { - inner: S, - attestation_service: Arc, - usage_service: Arc, - metrics_service: Arc, + pub inner: S, + pub attestation_service: Arc, + pub usage_service: Arc, + pub metrics_service: Arc, // IDs for usage tracking (database) - organization_id: Uuid, - workspace_id: Uuid, - api_key_id: Uuid, - model_id: Uuid, + pub organization_id: Uuid, + pub workspace_id: Uuid, + pub api_key_id: Uuid, + pub model_id: Uuid, #[allow(dead_code)] // Kept for potential debugging/logging use - model_name: String, - inference_type: crate::usage::ports::InferenceType, - service_start_time: Instant, - provider_start_time: Instant, - first_token_received: bool, - first_token_time: Option, + pub model_name: String, + pub inference_type: crate::usage::ports::InferenceType, + pub service_start_time: Instant, + pub provider_start_time: Instant, + pub first_token_received: bool, + pub first_token_time: Option, /// Time to first token in milliseconds (captured for DB storage) - ttft_ms: Option, + pub ttft_ms: Option, /// Token count for ITL calculation - token_count: i32, + pub token_count: i32, /// Last token time for ITL calculation - last_token_time: Option, + pub last_token_time: Option, /// Accumulated inter-token latency for average calculation - total_itl_ms: f64, + pub total_itl_ms: f64, // Pre-allocated low-cardinality metric tags (for Datadog/OTLP) - metric_tags: Vec, - concurrent_counter: Option>, + pub metric_tags: Vec, + pub concurrent_counter: Option>, /// Last received usage stats from streaming chunks - last_usage_stats: Option, + pub last_usage_stats: Option, /// Last chat ID from streaming chunks (for attestation and inference_id) - last_chat_id: Option, + pub last_chat_id: Option, /// Flag indicating the stream completed normally (received None from inner stream) /// If false when Drop is called, the client disconnected mid-stream - stream_completed: bool, + pub stream_completed: bool, /// Response ID when called from Responses API (for usage tracking FK) - response_id: Option, + pub response_id: Option, /// Last finish_reason from provider (e.g., "stop", "length", "tool_calls") - last_finish_reason: Option, + pub last_finish_reason: Option, /// Last error from provider (for determining stop_reason) - last_error: Option, - state: StreamState, + pub last_error: Option, + pub state: StreamState, /// Whether the model supports TEE attestation (false for external providers) - attestation_supported: bool, + pub attestation_supported: bool, } impl InterceptStream where S: Stream> + Unpin, { - /// Store attestation signature before sending [DONE] to client. - /// This runs in the hot path to ensure signature is available when client receives [DONE]. + /// Spawn attestation signature storage as a background task. + /// This no longer blocks [DONE] delivery to the client — the signature is + /// stored asynchronously after the stream completes. /// Skipped for external providers that don't support TEE attestation. - fn create_signature_future(&self) -> FinalizeFuture { - // Skip attestation for external providers (OpenAI, Anthropic, Gemini, etc.) + fn spawn_signature_storage(&self) { if !self.attestation_supported { - return Box::pin(async {}); + return; } let chat_id = match &self.last_chat_id { Some(id) => id.clone(), None => { tracing::warn!("Cannot store signature: no chat_id received in stream"); - return Box::pin(async {}); + return; + } + }; + + let handle = match tokio::runtime::Handle::try_current() { + Ok(h) => h, + Err(_) => { + tracing::error!("Cannot store chat signature: no Tokio runtime available"); + return; } }; let attestation_service = self.attestation_service.clone(); - Box::pin(async move { + handle.spawn(async move { match tokio::time::timeout( Duration::from_secs(FINALIZE_TIMEOUT_SECS), attestation_service.store_chat_signature_from_provider(&chat_id), @@ -138,7 +145,7 @@ where ); } } - }) + }); } /// Record usage and metrics. Called from Drop to ensure it always runs. @@ -316,82 +323,77 @@ where type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match &mut self.state { - StreamState::Streaming => { - match Pin::new(&mut self.inner).poll_next(cx) { - Poll::Ready(Some(Ok(ref event))) => { - let now = Instant::now(); - - if !self.first_token_received { - self.first_token_received = true; - self.first_token_time = Some(now); - let backend_ttft = now.duration_since(self.provider_start_time); - let e2e_ttft = now.duration_since(self.service_start_time); - self.ttft_ms = Some(e2e_ttft.as_millis() as i32); - self.last_token_time = Some(now); - let tags_str: Vec<&str> = - self.metric_tags.iter().map(|s| s.as_str()).collect(); - self.metrics_service.record_latency( - METRIC_LATENCY_TTFT, - backend_ttft, - &tags_str, - ); - self.metrics_service.record_latency( - METRIC_LATENCY_TTFT_TOTAL, - e2e_ttft, - &tags_str, - ); - } else if let Some(last_time) = self.last_token_time { - // Calculate inter-token latency - let itl = now.duration_since(last_time); - self.total_itl_ms += itl.as_secs_f64() * 1000.0; - self.token_count += 1; - self.last_token_time = Some(now); - } + match &mut self.state { + StreamState::Streaming => { + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(Ok(event))) => { + let now = Instant::now(); + + if !self.first_token_received { + self.first_token_received = true; + self.first_token_time = Some(now); + let backend_ttft = now.duration_since(self.provider_start_time); + let e2e_ttft = now.duration_since(self.service_start_time); + self.ttft_ms = Some(e2e_ttft.as_millis() as i32); + self.last_token_time = Some(now); + let tags_str: Vec<&str> = + self.metric_tags.iter().map(|s| s.as_str()).collect(); + self.metrics_service.record_latency( + METRIC_LATENCY_TTFT, + backend_ttft, + &tags_str, + ); + self.metrics_service.record_latency( + METRIC_LATENCY_TTFT_TOTAL, + e2e_ttft, + &tags_str, + ); + } else if let Some(last_time) = self.last_token_time { + // Calculate inter-token latency + let itl = now.duration_since(last_time); + self.total_itl_ms += itl.as_secs_f64() * 1000.0; + self.token_count += 1; + self.last_token_time = Some(now); + } - if let StreamChunk::Chat(ref chat_chunk) = event.chunk { - // Track chat_id for attestation (updated on each chunk) + if let StreamChunk::Chat(ref chat_chunk) = event.chunk { + // Only capture chat_id on first token (it never changes within a stream) + if self.last_chat_id.is_none() { self.last_chat_id = Some(chat_chunk.id.clone()); + } - // Track usage stats (updated on each chunk that has usage) - if let Some(usage) = &chat_chunk.usage { - self.last_usage_stats = Some(usage.clone()); - } + // Track usage stats (updated on each chunk that has usage) + if let Some(usage) = &chat_chunk.usage { + self.last_usage_stats = Some(usage.clone()); + } - // Track finish_reason from the final chunk (only set once at end) - if let Some(choice) = chat_chunk.choices.first() { - if let Some(ref reason) = choice.finish_reason { - self.last_finish_reason = Some(reason.clone()); - } + // Track finish_reason from the final chunk (only set once at end) + if let Some(choice) = chat_chunk.choices.first() { + if let Some(ref reason) = choice.finish_reason { + self.last_finish_reason = Some(reason.clone()); } } - return Poll::Ready(Some(Ok(event.clone()))); } - Poll::Ready(None) => { - self.stream_completed = true; - let signature_future = self.create_signature_future(); - self.state = StreamState::Finalizing(signature_future); - } - Poll::Ready(Some(Err(ref err))) => { - // Capture error for stop_reason in usage recording (handled in Drop) - // Note: We intentionally skip Finalizing state (attestation) for errors - // because partial completions cannot be verified by clients - self.last_error = Some(err.clone()); - return Poll::Ready(Some(Err(err.clone()))); - } - Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(event))) } - } - StreamState::Finalizing(ref mut future) => match future.as_mut().poll(cx) { - Poll::Ready(()) => { + Poll::Ready(None) => { + self.stream_completed = true; + // Spawn attestation in background — don't block [DONE] + self.spawn_signature_storage(); self.state = StreamState::Done; - return Poll::Ready(None); + Poll::Ready(None) } - Poll::Pending => return Poll::Pending, - }, - StreamState::Done => return Poll::Ready(None), + Poll::Ready(Some(Err(err))) => { + // Capture error for stop_reason in usage recording (handled in Drop) + // Note: We intentionally skip attestation for errors + // because partial completions cannot be verified by clients + self.last_error = Some(err.clone()); + Poll::Ready(Some(Err(err))) + } + Poll::Pending => Poll::Pending, + } } + StreamState::Done => Poll::Ready(None), } } } @@ -453,11 +455,18 @@ pub struct CompletionServiceImpl { org_concurrent_limits: Cache, /// Repository for fetching organization concurrent limits organization_limit_repository: Arc, + /// Cache for model resolution (avoids DB JOIN on every completion request) + model_resolution_cache: Cache>, } /// TTL for organization concurrent limit cache (5 minutes) const ORG_LIMIT_CACHE_TTL_SECS: u64 = 300; +/// TTL for model resolution cache (60 seconds). +/// Models rarely change, so a short TTL eliminates per-request DB JOINs +/// while still picking up model config changes within a minute. +const MODEL_RESOLUTION_CACHE_TTL_SECS: u64 = 60; + /// TTL for concurrent count cache entries (10 minutes). /// Safety net: if a counter gets stuck (e.g., due to a panic or proxy not propagating /// client disconnection), the entry expires and is replaced with a fresh zero counter. @@ -486,6 +495,11 @@ impl CompletionServiceImpl { .max_capacity(10_000) .build(); + let model_resolution_cache = Cache::builder() + .max_capacity(1_000) + .time_to_live(Duration::from_secs(MODEL_RESOLUTION_CACHE_TTL_SECS)) + .build(); + Self { inference_provider_pool, attestation_service, @@ -496,9 +510,28 @@ impl CompletionServiceImpl { concurrent_limit: DEFAULT_CONCURRENT_LIMIT, org_concurrent_limits, organization_limit_repository, + model_resolution_cache, } } + /// Resolve a model identifier with caching. Returns the model if found, or None. + /// Uses `try_get_with` to deduplicate concurrent misses for the same key, + /// preventing DB stampedes when many requests arrive for a cold model. + async fn resolve_model_cached( + &self, + identifier: &str, + ) -> Result, anyhow::Error> { + let id = identifier.to_string(); + let models_repo = self.models_repository.clone(); + + self.model_resolution_cache + .try_get_with(id.clone(), async move { + models_repo.resolve_and_get_model(&id).await + }) + .await + .map_err(|e| anyhow::anyhow!(e)) + } + /// Extract tools and tool_choice from extra HashMap if present. /// This handles the case where the Responses API places tools in request.extra. /// Returns (tools, tool_choice) and removes them from extra to avoid duplication. @@ -1007,13 +1040,9 @@ impl ports::CompletionServiceTrait for CompletionServiceImpl { extra, }; - // Resolve model name (could be an alias) and get model details in a single DB call + // Resolve model name (could be an alias) and get model details (cached, 60s TTL) // This also validates that the model exists and is active - let model = match self - .models_repository - .resolve_and_get_model(&request.model) - .await - { + let model = match self.resolve_model_cached(&request.model).await { Ok(Some(m)) => m, Ok(None) => { let err = ports::CompletionError::InvalidModel(format!( @@ -1144,13 +1173,9 @@ impl ports::CompletionServiceTrait for CompletionServiceImpl { extra, }; - // Resolve model name (could be an alias) and get model details in a single DB call + // Resolve model name (could be an alias) and get model details (cached, 60s TTL) // This also validates that the model exists and is active - let model = match self - .models_repository - .resolve_and_get_model(&request.model) - .await - { + let model = match self.resolve_model_cached(&request.model).await { Ok(Some(m)) => m, Ok(None) => { let err = ports::CompletionError::InvalidModel(format!( diff --git a/crates/services/src/inference_provider_pool/mod.rs b/crates/services/src/inference_provider_pool/mod.rs index 69b0d97ac..ce768c6ec 100644 --- a/crates/services/src/inference_provider_pool/mod.rs +++ b/crates/services/src/inference_provider_pool/mod.rs @@ -12,6 +12,10 @@ use regex::Regex; use serde::Deserialize; use std::{collections::HashMap, net::IpAddr, sync::Arc, time::Duration}; use tokio::sync::{Mutex, RwLock}; + +/// Alias for a synchronous, non-async mutex used for fast in-memory lookups +/// with no await points inside the critical section. +type SyncMutex = std::sync::Mutex; use tracing::{debug, info, warn}; type InferenceProviderTrait = dyn InferenceProvider + Send + Sync; @@ -66,10 +70,10 @@ pub struct InferenceProviderPool { external_providers: Arc>>>, /// Configuration for external providers (API keys, timeouts, etc.) external_configs: ExternalProvidersConfig, - /// Round-robin index for each model - load_balancer_index: Arc>>, - /// Map of chat_id -> provider for sticky routing - chat_id_mapping: Arc>>>, + /// Round-robin index for each model (std::sync::Mutex — no await points inside) + load_balancer_index: Arc>>, + /// Map of chat_id -> provider for sticky routing (TTL-bounded to prevent unbounded growth) + chat_id_mapping: moka::future::Cache>, /// Background task handle for periodic model discovery refresh refresh_task_handle: Arc>>>, /// Background task handle for periodic external provider refresh @@ -91,8 +95,11 @@ impl InferenceProviderPool { provider_mappings: Arc::new(RwLock::new(ProviderMappings::new())), external_providers: Arc::new(RwLock::new(HashMap::new())), external_configs, - load_balancer_index: Arc::new(RwLock::new(HashMap::new())), - chat_id_mapping: Arc::new(RwLock::new(HashMap::new())), + load_balancer_index: Arc::new(SyncMutex::new(HashMap::new())), + chat_id_mapping: moka::future::Cache::builder() + .max_capacity(100_000) + .time_to_live(Duration::from_secs(3600)) // 1 hour TTL for sticky routing + .build(), refresh_task_handle: Arc::new(Mutex::new(None)), external_refresh_task_handle: Arc::new(Mutex::new(None)), } @@ -640,8 +647,7 @@ impl InferenceProviderPool { chat_id: String, provider: Arc, ) { - let mut mapping = self.chat_id_mapping.write().await; - mapping.insert(chat_id.clone(), provider); + self.chat_id_mapping.insert(chat_id.clone(), provider).await; tracing::debug!("Stored chat_id mapping: {}", chat_id); } @@ -650,8 +656,7 @@ impl InferenceProviderPool { &self, chat_id: &str, ) -> Option> { - let mapping = self.chat_id_mapping.read().await; - mapping.get(chat_id).cloned() + self.chat_id_mapping.get(chat_id).await } /// Get providers with load balancing support @@ -705,31 +710,33 @@ impl InferenceProviderPool { return Some(providers); } - // Apply round-robin load balancing + // Apply round-robin load balancing. + // Build the index key in a single String allocation (no .clone() needed — + // entry().or_insert() only allocates on first insertion). let index_key = if let Some(pub_key) = model_pub_key { format!("pubkey:{}", pub_key) } else { format!("id:{}", model_id) }; - let mut indices = self.load_balancer_index.write().await; - let index = indices.entry(index_key.clone()).or_insert(0); - let selected_index = *index % providers.len(); - - // Increment for next request - *index = (*index + 1) % providers.len(); + let selected_index = { + let mut indices = self + .load_balancer_index + .lock() + .unwrap_or_else(|e| e.into_inner()); + let index = indices.entry(index_key).or_insert(0); + let selected = *index % providers.len(); + *index = (*index + 1) % providers.len(); + selected + }; - // Build ordered list following round-robin pattern: - // selected provider first, then continue round-robin (selected+1, selected+2, ...) - let mut ordered_providers = Vec::with_capacity(providers.len()); - for i in 0..providers.len() { - let provider_index = (selected_index + i) % providers.len(); - ordered_providers.push(providers[provider_index].clone()); - } + // Rotate the already-cloned providers vec in-place instead of building a new one. + let mut ordered_providers = providers; + ordered_providers.rotate_left(selected_index); tracing::debug!( - index_key = %index_key, - providers_count = providers.len(), + model_id = %model_id, + providers_count = ordered_providers.len(), selected_index = selected_index, "Prepared providers for fallback with round-robin priority" ); @@ -1667,19 +1674,22 @@ impl InferenceProviderPool { // Step 3: Clear load balancer indices debug!("Step 3: Clearing load balancer indices"); - let mut lb_index = self.load_balancer_index.write().await; - let index_count = lb_index.len(); - lb_index.clear(); - debug!("Cleared {} load balancer indices", index_count); - drop(lb_index); + let index_count = { + let mut lb_index = self + .load_balancer_index + .lock() + .unwrap_or_else(|e| e.into_inner()); + let count = lb_index.len(); + lb_index.clear(); + debug!("Cleared {} load balancer indices", count); + count + }; // Step 4: Clear chat_id to provider mappings debug!("Step 4: Clearing chat session mappings"); - let mut chat_mapping = self.chat_id_mapping.write().await; - let chat_count = chat_mapping.len(); - chat_mapping.clear(); + let chat_count = self.chat_id_mapping.entry_count(); + self.chat_id_mapping.invalidate_all(); debug!("Cleared {} chat session mappings", chat_count); - drop(chat_mapping); info!( "Inference provider pool shutdown completed. Cleaned up: {} models, {} pubkeys, {} external providers, {} load balancer indices, {} chat mappings",