From a8b711167ab4fb29693b630dde55d476c0b914da Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 4 May 2026 21:56:16 +0200 Subject: [PATCH 1/4] refactor(gpu): centralize driver request validation Signed-off-by: Evan Lezar --- .../openshell-driver-kubernetes/src/driver.rs | 4 +++ crates/openshell-driver-vm/src/driver.rs | 25 +++++++++++-------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 56b73447a..f6e6cb880 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -195,6 +195,10 @@ impl KubernetesComputeDriver { pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); + self.validate_gpu_request(gpu_requested).await + } + + async fn validate_gpu_request(&self, gpu_requested: bool) -> Result<(), tonic::Status> { if gpu_requested && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index b797f4835..aec01b55b 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -1461,15 +1461,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - if spec.gpu && !gpu_enabled { - return Err(Status::failed_precondition( - "GPU support is not enabled on this driver; start with --gpu", - )); - } - - if !spec.gpu && !spec.gpu_device.is_empty() { - return Err(Status::invalid_argument("gpu_device requires gpu=true")); - } + validate_gpu_request(spec.gpu, &spec.gpu_device, gpu_enabled)?; if let Some(template) = spec.template.as_ref() { if 
!template.agent_socket_path.is_empty() { @@ -1491,7 +1483,6 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu Ok(()) } -#[allow(clippy::result_large_err)] fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { if sandbox_id.is_empty() { return Err(Status::invalid_argument("sandbox id is required")); @@ -1517,6 +1508,20 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } +#[allow(clippy::result_large_err)] +fn validate_gpu_request(gpu: bool, gpu_device: &str, gpu_enabled: bool) -> Result<(), Status> { + if gpu && !gpu_enabled { + return Err(Status::failed_precondition( + "GPU support is not enabled on this driver; start with --gpu", + )); + } + + if !gpu && !gpu_device.is_empty() { + return Err(Status::invalid_argument("gpu_device requires gpu=true")); + } + Ok(()) +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { From ca0003fe4b898b7312cd036de60c55d45b337109 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 4 May 2026 22:12:28 +0200 Subject: [PATCH 2/4] refactor(vm): derive GPU device request once Signed-off-by: Evan Lezar --- crates/openshell-driver-vm/src/driver.rs | 31 +++++++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index aec01b55b..38ddb664b 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -358,9 +358,10 @@ impl VmDriver { return Err(Status::already_exists("sandbox already exists")); } - let spec = sandbox.spec.as_ref(); - let is_gpu = spec.is_some_and(|s| s.gpu); - let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); + let gpu_device = sandbox + .spec + .as_ref() + .and_then(|spec| requested_gpu_device(spec.gpu, &spec.gpu_device)); let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id)?; let rootfs 
= state_dir.join("rootfs"); @@ -437,7 +438,7 @@ impl VmDriver { ))); } - let gpu_bdf = if is_gpu { + let gpu_bdf = if let Some(gpu_device) = gpu_device { let inventory = self .gpu_inventory .as_ref() @@ -1508,6 +1509,10 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } +fn requested_gpu_device(gpu: bool, gpu_device: &str) -> Option<&str> { + gpu.then_some(gpu_device) +} + #[allow(clippy::result_large_err)] fn validate_gpu_request(gpu: bool, gpu_device: &str, gpu_enabled: bool) -> Result<(), Status> { if gpu && !gpu_enabled { @@ -2579,6 +2584,24 @@ mod tests { assert!(err.message().contains("gpu_device requires gpu=true")); } + #[test] + fn requested_gpu_device_returns_none_without_gpu_request() { + assert_eq!(requested_gpu_device(false, ""), None); + } + + #[test] + fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { + assert_eq!(requested_gpu_device(true, ""), Some("")); + } + + #[test] + fn requested_gpu_device_returns_explicit_device_id() { + assert_eq!( + requested_gpu_device(true, "0000:2d:00.0"), + Some("0000:2d:00.0") + ); + } + #[test] fn validate_vm_sandbox_rejects_platform_config() { let sandbox = Sandbox { From 5ce7a158cea3e8270c047f7a9e5d90de209e2817 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 4 May 2026 22:14:48 +0200 Subject: [PATCH 3/4] feat(gpu): introduce GPU request spec Signed-off-by: Evan Lezar --- crates/openshell-cli/src/run.rs | 59 +++++++++++++-- crates/openshell-core/src/gpu.rs | 54 ++++++++++---- crates/openshell-driver-docker/src/lib.rs | 25 +++++-- crates/openshell-driver-docker/src/tests.rs | 32 +++++--- .../openshell-driver-kubernetes/src/driver.rs | 53 ++++++++++++-- .../openshell-driver-podman/src/container.rs | 71 +++++++++++++++++- crates/openshell-driver-podman/src/driver.rs | 14 ++-- crates/openshell-driver-vm/src/driver.rs | 73 ++++++++++++------- crates/openshell-server/src/compute/mod.rs | 60 +++++++++++++-- .../openshell-server/src/grpc/validation.rs | 6 +- 
e2e/python/conftest.py | 4 +- proto/compute_driver.proto | 23 ++++-- proto/openshell.proto | 24 ++++-- python/openshell/_proto/__init__.py | 8 +- 14 files changed, 399 insertions(+), 107 deletions(-) diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 3205b8f68..a5aac80a4 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -32,11 +32,11 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, GetSandboxPolicyStatusRequest, - GetSandboxRequest, GetServiceRequest, HealthRequest, ImportProviderProfilesRequest, + GetSandboxRequest, GetServiceRequest, GpuSpec, HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, - ListServicesRequest, PolicySource, PolicyStatus, Provider, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, + ListServicesRequest, PlacementRequirements, PolicySource, PolicyStatus, Provider, + ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, @@ -1538,8 +1538,7 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: requested_gpu, - gpu_device: gpu_device.unwrap_or_default().to_string(), + placement: placement_requirements_from_cli(requested_gpu, gpu_device), policy, providers: configured_providers, template, @@ -1971,6 +1970,20 @@ pub async fn sandbox_create( } } +fn 
placement_requirements_from_cli( + requested_gpu: bool, + gpu_device: Option<&str>, +) -> Option<PlacementRequirements> { + requested_gpu.then(|| PlacementRequirements { + gpu: Some(GpuSpec { + device_id: gpu_device + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default(), + }), + }) +} + /// Resolved source for the `--from` flag on `sandbox create`. #[derive(Debug)] enum ResolvedSource { @@ -6017,9 +6030,10 @@ mod tests { gateway_env_override_warning, gateway_select_with, gateway_type_label, git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, - parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message, - ready_false_condition_message, resolve_from, sandbox_should_persist, - service_expose_status_error, service_url_for_gateway, + parse_credential_pairs, placement_requirements_from_cli, plaintext_gateway_is_remote, + provisioning_timeout_message, ready_false_condition_message, resolve_from, + sandbox_should_persist, service_expose_status_error, service_url_for_gateway, + source_requests_gpu, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -6296,6 +6310,35 @@ mod tests { } } + #[test] + fn source_requests_gpu_detects_known_community_gpu_name() { + assert!(source_requests_gpu("nvidia-gpu")); + assert!(!source_requests_gpu("base")); + } + + #[test] + fn placement_requirements_from_cli_uses_presence_with_empty_device_ids_for_default_gpu() { + let request = placement_requirements_from_cli(true, None) + .expect("placement requirements should be present"); + let gpu = request.gpu.expect("gpu request should be present"); + + assert!(gpu.device_id.is_empty()); + } + + #[test] + fn placement_requirements_from_cli_maps_gpu_device_to_one_device_id() { + let request = placement_requirements_from_cli(true, Some("0000:2d:00.0")) + .expect("placement requirements should be present"); + let gpu = 
request.gpu.expect("gpu request should be present"); + + assert_eq!(gpu.device_id, vec!["0000:2d:00.0"]); + } + + #[test] + fn placement_requirements_from_cli_omits_placement_when_not_requested() { + assert!(placement_requirements_from_cli(false, Some("0")).is_none()); + } + #[test] fn resolve_from_classifies_existing_dockerfile_path() { let temp = tempfile::tempdir().expect("failed to create tempdir"); diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 5df8702ed..03a85d435 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -4,21 +4,19 @@ //! Shared GPU request helpers. use crate::config::CDI_GPU_DEVICE_ALL; +use crate::proto::compute::v1::GpuSpec; -/// Resolve the existing GPU request fields into CDI device identifiers. +/// Resolve a driver GPU request into CDI device identifiers. /// -/// `None` means no GPU was requested. A GPU request with no explicit device -/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes -/// through unchanged. +/// `None` means no GPU was requested. Presence with no explicit device IDs +/// uses the CDI all-GPU request; otherwise the driver-native IDs pass through. 
#[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option<Vec<String>> { - gpu.then(|| { - if gpu_device.is_empty() { - vec![CDI_GPU_DEVICE_ALL.to_string()] - } else { - vec![gpu_device.to_string()] - } - }) +pub fn cdi_gpu_device_ids(gpu: Option<&GpuSpec>) -> Option<Vec<String>> { + match gpu { + Some(gpu) if gpu.device_id.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + Some(gpu) => Some(gpu.device_id.clone()), + None => None, + } } #[cfg(test)] @@ -27,22 +25,46 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, ""), None); + assert_eq!(cdi_gpu_device_ids(None), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + let request = GpuSpec { device_id: vec![] }; + assert_eq!( - cdi_gpu_device_ids(true, ""), + cdi_gpu_device_ids(Some(&request)), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] - fn cdi_gpu_device_ids_passes_explicit_device_id_through() { + fn cdi_gpu_device_ids_passes_single_device_id_through() { + let request = GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + }; + assert_eq!( - cdi_gpu_device_ids(true, "nvidia.com/gpu=0"), + cdi_gpu_device_ids(Some(&request)), Some(vec!["nvidia.com/gpu=0".to_string()]) ); } + + #[test] + fn cdi_gpu_device_ids_passes_multiple_device_ids_through() { + let request = GpuSpec { + device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + }; + + assert_eq!( + cdi_gpu_device_ids(Some(&request)), + Some(vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ]) + ); + } } diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 6059596ab..b3c8535ef 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -24,7 +24,7 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, 
DriverCondition, DriverSandbox, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, - ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, + GpuSpec, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, @@ -310,7 +310,12 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; + Self::validate_gpu_request( + spec.placement + .as_ref() + .and_then(|placement| placement.gpu.as_ref()), + config.supports_gpu, + )?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -330,8 +335,8 @@ impl DockerComputeDriver { Ok(()) } - fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { - if gpu && !supports_gpu { + fn validate_gpu_request(gpu: Option<&GpuSpec>, supports_gpu: bool) -> Result<(), Status> { + if gpu.is_some() && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. 
Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); } @@ -945,6 +950,8 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: bool, gpu_device: &str) -> Option<Vec<DeviceRequest>> { - cdi_gpu_device_ids(gpu, gpu_device).map(|device_ids| { +fn docker_gpu_device_requests(gpu: Option<&GpuSpec>) -> Option<Vec<DeviceRequest>> { + cdi_gpu_device_ids(gpu).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), device_ids: Some(device_ids), @@ -996,8 +1001,12 @@ fn build_container_create_body( host_config: Some(HostConfig { nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, - device_requests: docker_gpu_device_requests(spec.gpu, &spec.gpu_device), - binds: Some(build_binds(config)), + device_requests: docker_gpu_device_requests( + spec.placement + .as_ref() + .and_then(|placement| placement.gpu.as_ref()), + ), + mounts: Some(build_mounts(config)), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), maximum_retry_count: None, diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index df68d39d6..5522c2e04 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -4,7 +4,8 @@ use super::*; use openshell_core::config::{CDI_GPU_DEVICE_ALL, DEFAULT_SERVER_PORT}; use openshell_core::proto::compute::v1::{ - DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, + DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, GpuSpec, + PlacementRequirements, }; use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -31,8 +32,7 @@ fn test_sandbox() -> DriverSandbox { resources: None, platform_config: None, }), - gpu: false, - gpu_device: String::new(), + placement: None, }), status: None, } } @@ -487,7 +487,9 @@ fn build_container_create_body_clears_inherited_cmd() { fn 
validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -500,7 +502,9 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -518,13 +522,18 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { } #[test] -fn build_container_create_body_passes_explicit_cdi_device_id_through() { +fn build_container_create_body_passes_explicit_cdi_device_ids_through() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; - spec.gpu_device = "nvidia.com/gpu=0".to_string(); + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + }), + }); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -537,7 +546,10 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { assert_eq!(request.driver.as_deref(), Some("cdi")); assert_eq!( request.device_ids.as_ref().unwrap(), - &vec!["nvidia.com/gpu=0".to_string()] + &vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ] ); } diff --git a/crates/openshell-driver-kubernetes/src/driver.rs 
b/crates/openshell-driver-kubernetes/src/driver.rs index f6e6cb880..b5b963f3a 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -15,7 +15,7 @@ use openshell_core::proto::compute::v1::{ DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + GetCapabilitiesResponse, GpuSpec, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, }; use std::collections::BTreeMap; @@ -61,6 +61,15 @@ const SANDBOX_MANAGED_VALUE: &str = "openshell"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; const GPU_RESOURCE_QUANTITY: &str = "1"; +fn gpu_from_spec(spec: Option<&SandboxSpec>) -> Option<&GpuSpec> { + spec.and_then(|spec| spec.placement.as_ref()) + .and_then(|placement| placement.gpu.as_ref()) +} + +fn gpu_has_explicit_device_ids(gpu: Option<&GpuSpec>) -> bool { + gpu.is_some_and(|gpu| !gpu.device_id.is_empty()) +} + // --------------------------------------------------------------------------- // Default workspace persistence (temporary — will be replaced by snapshotting) // --------------------------------------------------------------------------- @@ -194,12 +203,17 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); - self.validate_gpu_request(gpu_requested).await + let gpu = gpu_from_spec(sandbox.spec.as_ref()); + self.validate_gpu_request(gpu).await } - async fn validate_gpu_request(&self, gpu_requested: bool) -> Result<(), tonic::Status> { - if gpu_requested + async fn validate_gpu_request(&self, gpu: Option<&GpuSpec>) -> Result<(), tonic::Status> { 
+ if gpu_has_explicit_device_ids(gpu) { + return Err(tonic::Status::invalid_argument( + "kubernetes compute driver does not support explicit GPU device IDs", + )); + } + if gpu.is_some() && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) })? @@ -295,6 +309,12 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { + if gpu_has_explicit_device_ids(gpu_from_spec(sandbox.spec.as_ref())) { + return Err(KubernetesDriverError::Precondition( + "kubernetes compute driver does not support explicit GPU device IDs".to_string(), + )); + } + let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -1015,7 +1035,13 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s( + template, + gpu_from_spec(Some(spec)).is_some(), + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1047,7 +1073,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - spec.is_some_and(|s| s.gpu), + gpu_from_spec(spec).is_some(), &pod_env, inject_workspace, params, @@ -1808,6 +1834,19 @@ mod tests { ); } + #[test] + fn gpu_has_explicit_device_ids_only_when_ids_are_present() { + use openshell_core::proto::compute::v1::GpuSpec; + + assert!(!gpu_has_explicit_device_ids(None)); + assert!(!gpu_has_explicit_device_ids(Some(&GpuSpec { + device_id: vec![], + }))); + assert!(gpu_has_explicit_device_ids(Some(&GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + }))); + } + #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { diff --git a/crates/openshell-driver-podman/src/container.rs 
b/crates/openshell-driver-podman/src/container.rs index 5b9b0d735..3c974e55c 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -345,8 +345,12 @@ fn build_resource_limits(sandbox: &DriverSandbox) -> ResourceLimits { /// Build CDI GPU device list if GPU is requested. fn build_devices(sandbox: &DriverSandbox) -> Option<Vec<LinuxDevice>> { - let spec = sandbox.spec.as_ref()?; - cdi_gpu_device_ids(spec.gpu, &spec.gpu_device).map(|device_ids| { + let gpu = sandbox + .spec + .as_ref() + .and_then(|spec| spec.placement.as_ref()) + .and_then(|placement| placement.gpu.as_ref()); + cdi_gpu_device_ids(gpu).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -782,6 +786,69 @@ mod tests { ); } + #[test] + fn container_spec_omits_devices_without_gpu_request() { + let sandbox = test_sandbox("test-id", "test-name"); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert!(spec.get("devices").is_none()); + } + + #[test] + fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { + use openshell_core::config::CDI_GPU_DEVICE_ALL; + use openshell_core::proto::compute::v1::{ + DriverSandboxSpec, GpuSpec, PlacementRequirements, + }; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some(CDI_GPU_DEVICE_ALL) + ); + } + + #[test] + fn container_spec_passes_explicit_cdi_device_ids_through() { + use openshell_core::proto::compute::v1::{ + DriverSandboxSpec, GpuSpec, PlacementRequirements, + }; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + 
device_id: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + }), + }), + ..Default::default() + }); + let config = test_config(); + let spec = build_container_spec(&sandbox, &config); + + assert_eq!( + spec["devices"][0]["path"].as_str(), + Some("nvidia.com/gpu=0") + ); + assert_eq!( + spec["devices"][1]["path"].as_str(), + Some("nvidia.com/gpu=1") + ); + } + #[test] fn container_spec_uses_secret_env_not_plaintext() { let sandbox = test_sandbox("test-id", "test-name"); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index f78c5c730..07cff9cfa 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -10,7 +10,7 @@ use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; use openshell_core::ComputeDriverError; -use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse, GpuSpec}; use tracing::{info, warn}; impl From for ComputeDriverError { @@ -198,12 +198,16 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); - Self::validate_gpu_request(gpu_requested) + let gpu = sandbox + .spec + .as_ref() + .and_then(|spec| spec.placement.as_ref()) + .and_then(|placement| placement.gpu.as_ref()); + Self::validate_gpu_request(gpu) } - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { - if gpu_requested && !Self::has_gpu_capacity() { + fn validate_gpu_request(gpu: Option<&GpuSpec>) -> Result<(), ComputeDriverError> { + if gpu.is_some() && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), )); diff --git a/crates/openshell-driver-vm/src/driver.rs 
b/crates/openshell-driver-vm/src/driver.rs index 38ddb664b..a635e2401 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -25,7 +25,7 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, - GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, + GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, GpuSpec, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, @@ -358,10 +358,7 @@ impl VmDriver { return Err(Status::already_exists("sandbox already exists")); } - let gpu_device = sandbox - .spec - .as_ref() - .and_then(|spec| requested_gpu_device(spec.gpu, &spec.gpu_device)); + let gpu_device = requested_gpu_device(sandbox_gpu(sandbox)); let state_dir = sandbox_state_dir(&self.config.state_dir, &sandbox.id)?; let rootfs = state_dir.join("rootfs"); @@ -1462,7 +1459,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - validate_gpu_request(spec.gpu, &spec.gpu_device, gpu_enabled)?; + validate_gpu_request(sandbox_gpu(sandbox), gpu_enabled)?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -1509,20 +1506,31 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } -fn requested_gpu_device(gpu: bool, gpu_device: &str) -> Option<&str> { - gpu.then_some(gpu_device) +fn sandbox_gpu(sandbox: &Sandbox) -> Option<&GpuSpec> { + sandbox + .spec + 
.as_ref() + .and_then(|spec| spec.placement.as_ref()) + .and_then(|placement| placement.gpu.as_ref()) +} + +fn requested_gpu_device(gpu: Option<&GpuSpec>) -> Option<&str> { + let gpu = gpu?; + Some(gpu.device_id.first().map_or("", String::as_str)) } #[allow(clippy::result_large_err)] -fn validate_gpu_request(gpu: bool, gpu_device: &str, gpu_enabled: bool) -> Result<(), Status> { - if gpu && !gpu_enabled { +fn validate_gpu_request(gpu: Option<&GpuSpec>, gpu_enabled: bool) -> Result<(), Status> { + if gpu.is_some() && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", )); } - if !gpu && !gpu_device.is_empty() { - return Err(Status::invalid_argument("gpu_device requires gpu=true")); + if gpu.is_some_and(|gpu| gpu.device_id.len() > 1) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU device ID", + )); } Ok(()) } @@ -2529,7 +2537,8 @@ mod tests { use super::*; use crate::gpu::{SubnetAllocator, allocate_vsock_cid, mac_from_sandbox_id, tap_device_name}; use openshell_core::proto::compute::v1::{ - DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, + DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, GpuSpec, + PlacementRequirements, }; use prost_types::{Struct, Value, value::Kind}; use std::fs; @@ -2543,7 +2552,9 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }), ..Default::default() }), ..Default::default() @@ -2559,7 +2570,9 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }), ..Default::default() }), ..Default::default() @@ -2568,38 +2581,44 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + 
fn validate_vm_sandbox_rejects_multiple_gpu_device_ids() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: false, - gpu_device: "0000:2d:00.0".to_string(), + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], + }), + }), ..Default::default() }), ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device without gpu should be rejected"); + .expect_err("multiple GPU device IDs should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device requires gpu=true")); + assert!(err.message().contains("at most one GPU device ID")); } #[test] fn requested_gpu_device_returns_none_without_gpu_request() { - assert_eq!(requested_gpu_device(false, ""), None); + assert_eq!(requested_gpu_device(None), None); } #[test] fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { - assert_eq!(requested_gpu_device(true, ""), Some("")); + let gpu = GpuSpec { device_id: vec![] }; + + assert_eq!(requested_gpu_device(Some(&gpu)), Some("")); } #[test] - fn requested_gpu_device_returns_explicit_device_id() { - assert_eq!( - requested_gpu_device(true, "0000:2d:00.0"), - Some("0000:2d:00.0") - ); + fn requested_gpu_device_returns_first_explicit_device_id() { + let gpu = GpuSpec { + device_id: vec!["0000:2d:00.0".to_string()], + }; + + assert_eq!(requested_gpu_device(Some(&gpu)), Some("0000:2d:00.0")); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index d2fd34011..ca49654c9 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -18,7 +18,8 @@ use futures::{Stream, StreamExt}; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, 
DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, + DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, GpuSpec as DriverGpuSpec, + ListSandboxesRequest, PlacementRequirements as DriverPlacementRequirements, ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, watch_sandboxes_event, @@ -1130,8 +1131,14 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .template .as_ref() .map(driver_sandbox_template_from_public), - gpu: spec.gpu, - gpu_device: spec.gpu_device.clone(), + placement: spec + .placement + .as_ref() + .map(|placement| DriverPlacementRequirements { + gpu: placement.gpu.as_ref().map(|gpu| DriverGpuSpec { + device_id: gpu.device_id.clone(), + }), + }), } } @@ -1491,7 +1498,9 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + let gpu_requested = spec + .and_then(|sandbox_spec| sandbox_spec.placement.as_ref()) + .is_some_and(|placement| placement.gpu.is_some()); if !gpu_requested { return; } @@ -1653,6 +1662,7 @@ mod tests { CreateSandboxResponse, DeleteSandboxResponse, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateResponse, }; + use openshell_core::proto::{GpuSpec, PlacementRequirements}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; @@ -1669,6 +1679,30 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_request_device_ids() { + let public = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + }), + }), + ..Default::default() + }; + + let driver = 
driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .placement + .expect("driver placement requirements should be present") + .gpu + .expect("driver GPU request should be present") + .device_id, + vec!["nvidia.com/gpu=0".to_string()] + ); + } + fn struct_value( fields: impl IntoIterator, prost_types::Value)>, ) -> prost_types::Value { @@ -2117,7 +2151,9 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }), ..Default::default() }), ); @@ -2149,7 +2185,7 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: false, + placement: None, ..Default::default() }), ); @@ -2376,7 +2412,9 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }), ..Default::default() }), ..sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning) @@ -2399,7 +2437,13 @@ mod tests { SandboxPhase::try_from(stored.phase).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); + assert!( + stored + .spec + .as_ref() + .and_then(|spec| spec.placement.as_ref()) + .is_some_and(|placement| placement.gpu.is_some()) + ); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 160b7e031..7b34a188b 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -642,7 +642,7 @@ pub(super) fn level_matches(log_level: &str, min_level: &str) -> bool { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::SandboxSpec; + use openshell_core::proto::{GpuSpec, PlacementRequirements, SandboxSpec}; use std::collections::HashMap; use tonic::Code; @@ -668,7 +668,9 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = 
SandboxSpec { - gpu: true, + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { device_id: vec![] }), + }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); diff --git a/e2e/python/conftest.py b/e2e/python/conftest.py index 712704929..5b3b1b882 100644 --- a/e2e/python/conftest.py +++ b/e2e/python/conftest.py @@ -101,6 +101,8 @@ def gpu_sandbox_spec() -> datamodel_pb2.SandboxSpec: # override (e.g. a locally-built or registry-mirrored image). image = os.environ.get("OPENSHELL_E2E_GPU_IMAGE", "") return datamodel_pb2.SandboxSpec( - gpu=True, + placement=datamodel_pb2.PlacementRequirements( + gpu=datamodel_pb2.GPUSpec(), + ), template=datamodel_pb2.SandboxTemplate(image=image), ) diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 3c4308f3f..5cee0cbb4 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -78,18 +78,29 @@ message DriverSandbox { // Driver-owned provisioning inputs required to create a sandbox. message DriverSandboxSpec { + reserved 9, 10; + // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. map environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + // Optional placement requirements for the sandbox workload. + PlacementRequirements placement = 11; +} + +// Driver-owned placement requirements for selecting compatible compute resources. +message PlacementRequirements { + // Request GPU resources for this sandbox. Presence indicates a GPU request. + GPUSpec gpu = 1; +} + +// Driver-native GPU placement details. 
+message GPUSpec { + // Optional driver-native device identifiers. Empty means the driver chooses + // its default GPU assignment behavior. + repeated string device_id = 1; } // Driver-owned runtime template consumed by the compute platform. diff --git a/proto/openshell.proto b/proto/openshell.proto index bb2ce6cec..093521723 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -240,6 +240,8 @@ message Sandbox { // Desired sandbox configuration provided through the public API. message SandboxSpec { + reserved 9, 10; + // Log level exposed to processes running inside the sandbox. string log_level = 1; // Environment variables injected into the sandbox runtime. @@ -250,12 +252,22 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + // Optional placement requirements for the sandbox workload. + PlacementRequirements placement = 11; +} + +// Public placement requirements for selecting compatible compute resources. +message PlacementRequirements { + // Request GPU resources for this sandbox. Presence indicates a GPU request. + GPUSpec gpu = 1; +} + +// Public GPU placement details. Device identifiers are interpreted by the +// selected compute driver. +message GPUSpec { + // Optional driver-native device identifiers. Empty means the driver chooses + // its default GPU assignment behavior. + repeated string device_id = 1; } // Public sandbox template mapped onto compute-driver template inputs. 
diff --git a/python/openshell/_proto/__init__.py b/python/openshell/_proto/__init__.py index 3ace22421..0763b7b53 100644 --- a/python/openshell/_proto/__init__.py +++ b/python/openshell/_proto/__init__.py @@ -2,7 +2,13 @@ # Sandbox messages and phase enums moved into openshell.proto. Keep aliases on # datamodel_pb2 so existing Python callers and E2E tests continue to work. -for _name in ("Sandbox", "SandboxSpec", "SandboxTemplate"): +for _name in ( + "Sandbox", + "SandboxSpec", + "SandboxTemplate", + "PlacementRequirements", + "GPUSpec", +): if not hasattr(datamodel_pb2, _name): setattr(datamodel_pb2, _name, getattr(openshell_pb2, _name)) From 930c5819838be78290b506b6c59a99d8dfe67195 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 13 May 2026 17:55:10 +0200 Subject: [PATCH 4/4] feat(gpu): add GPU count placement support Signed-off-by: Evan Lezar --- crates/openshell-cli/src/main.rs | 55 +++++++- crates/openshell-cli/src/run.rs | 46 +++++-- .../sandbox_create_lifecycle_integration.rs | 5 + crates/openshell-core/src/gpu.rs | 7 +- crates/openshell-driver-docker/src/lib.rs | 17 ++- crates/openshell-driver-docker/src/tests.rs | 29 ++++- .../openshell-driver-kubernetes/src/driver.rs | 121 +++++++++++++----- .../openshell-driver-podman/src/container.rs | 53 +------- crates/openshell-driver-podman/src/driver.rs | 31 +++++ crates/openshell-driver-vm/src/driver.rs | 93 +++++++++++++- crates/openshell-server/src/compute/mod.rs | 37 +++++- .../openshell-server/src/grpc/validation.rs | 72 ++++++++++- docs/sandboxes/manage-sandboxes.mdx | 13 ++ proto/compute_driver.proto | 6 +- proto/openshell.proto | 6 +- 15 files changed, 483 insertions(+), 108 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 9cffb243b..c7ab0702e 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1029,6 +1029,7 @@ enum DoctorCommands { Check, } +#[allow(clippy::large_enum_variant)] #[derive(Subcommand, Debug)] enum 
SandboxCommands { /// Create a sandbox. @@ -1087,9 +1088,13 @@ enum SandboxCommands { /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu")] + #[arg(long, requires = "gpu", conflicts_with = "gpu_count")] gpu_device: Option, + /// Request a specific number of GPUs. Mutually exclusive with --gpu-device. + #[arg(long, value_parser = clap::value_parser!(u32).range(1..), conflicts_with = "gpu_device")] + gpu_count: Option, + /// Provider names to attach to this sandbox. #[arg(long = "provider")] providers: Vec, @@ -2365,6 +2370,7 @@ async fn main() -> Result<()> { editor, gpu, gpu_device, + gpu_count, providers, policy, forward, @@ -2431,6 +2437,7 @@ async fn main() -> Result<()> { keep, gpu, gpu_device.as_deref(), + gpu_count, editor, &providers, policy.as_deref(), @@ -3641,6 +3648,52 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_count_parses_without_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "2"]) + .expect("sandbox create --gpu-count should parse"); + + if let Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, gpu_count, .. }), + .. 
+ }) = cli.command + { + assert!(!gpu); + assert_eq!(gpu_count, Some(2)); + } else { + panic!("expected SandboxCommands::Create"); + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "0"]); + + assert!( + result.is_err(), + "sandbox create --gpu-count 0 should be rejected" + ); + } + + #[test] + fn sandbox_create_gpu_count_conflicts_with_gpu_device() { + let result = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "--gpu-device", + "0", + "--gpu-count", + "2", + ]); + + assert!( + result.is_err(), + "sandbox create should reject --gpu-count with --gpu-device" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index a5aac80a4..e89b258dc 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1468,6 +1468,7 @@ pub async fn sandbox_create( keep: bool, gpu: bool, gpu_device: Option<&str>, + gpu_count: Option, editor: Option, providers: &[String], policy: Option<&str>, @@ -1518,7 +1519,8 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); + let requested_gpu = + gpu || gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu); let inferred_types: Vec = inferred_provider_type(command).into_iter().collect(); let configured_providers = ensure_required_providers( @@ -1538,7 +1540,7 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { - placement: placement_requirements_from_cli(requested_gpu, gpu_device), + placement: placement_requirements_from_cli(requested_gpu, gpu_device, gpu_count), policy, providers: configured_providers, template, @@ -1973,13 +1975,20 @@ pub async fn sandbox_create( fn placement_requirements_from_cli( requested_gpu: bool, 
gpu_device: Option<&str>, + gpu_count: Option, ) -> Option { + let requested_gpu = requested_gpu || gpu_count.is_some(); requested_gpu.then(|| PlacementRequirements { gpu: Some(GpuSpec { - device_id: gpu_device - .filter(|device_id| !device_id.is_empty()) - .map(|device_id| vec![device_id.to_string()]) - .unwrap_or_default(), + device_id: if gpu_count.is_none() { + gpu_device + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default() + } else { + Vec::new() + }, + count: gpu_count, }), }) } @@ -6033,7 +6042,6 @@ mod tests { parse_credential_pairs, placement_requirements_from_cli, plaintext_gateway_is_remote, provisioning_timeout_message, ready_false_condition_message, resolve_from, sandbox_should_persist, service_expose_status_error, service_url_for_gateway, - source_requests_gpu, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -6311,32 +6319,44 @@ mod tests { } #[test] - fn source_requests_gpu_detects_known_community_gpu_name() { - assert!(source_requests_gpu("nvidia-gpu")); - assert!(!source_requests_gpu("base")); + fn image_requests_gpu_detects_known_community_gpu_name() { + assert!(image_requests_gpu("nvidia-gpu")); + assert!(!image_requests_gpu("base")); } #[test] fn placement_requirements_from_cli_uses_presence_with_empty_device_ids_for_default_gpu() { - let request = placement_requirements_from_cli(true, None) + let request = placement_requirements_from_cli(true, None, None) .expect("placement requirements should be present"); let gpu = request.gpu.expect("gpu request should be present"); assert!(gpu.device_id.is_empty()); + assert_eq!(gpu.count, None); } #[test] fn placement_requirements_from_cli_maps_gpu_device_to_one_device_id() { - let request = placement_requirements_from_cli(true, Some("0000:2d:00.0")) + let request = placement_requirements_from_cli(true, Some("0000:2d:00.0"), None) .expect("placement requirements should be present"); let gpu = request.gpu.expect("gpu request should be 
present"); assert_eq!(gpu.device_id, vec!["0000:2d:00.0"]); + assert_eq!(gpu.count, None); + } + + #[test] + fn placement_requirements_from_cli_maps_gpu_count() { + let request = placement_requirements_from_cli(false, None, Some(2)) + .expect("placement requirements should be present"); + let gpu = request.gpu.expect("gpu request should be present"); + + assert!(gpu.device_id.is_empty()); + assert_eq!(gpu.count, Some(2)); } #[test] fn placement_requirements_from_cli_omits_placement_when_not_requested() { - assert!(placement_requirements_from_cli(false, Some("0")).is_none()); + assert!(placement_requirements_from_cli(false, Some("0"), None).is_none()); } #[test] diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 6e7d66d11..ab6aa1d38 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -671,6 +671,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { false, None, None, + None, &[], None, None, @@ -710,6 +711,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { false, None, None, + None, &[], None, None, @@ -752,6 +754,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { false, None, None, + None, &[], None, None, @@ -794,6 +797,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { false, None, None, + None, &[], None, None, @@ -836,6 +840,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { false, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 03a85d435..85b6c9644 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -30,7 +30,10 @@ mod tests { #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { - let request = 
GpuSpec { device_id: vec![] }; + let request = GpuSpec { + device_id: vec![], + count: None, + }; assert_eq!( cdi_gpu_device_ids(Some(&request)), @@ -42,6 +45,7 @@ mod tests { fn cdi_gpu_device_ids_passes_single_device_id_through() { let request = GpuSpec { device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, }; assert_eq!( @@ -57,6 +61,7 @@ mod tests { "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string(), ], + count: None, }; assert_eq!( diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index b3c8535ef..67182a8b7 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -336,6 +336,21 @@ impl DockerComputeDriver { } fn validate_gpu_request(gpu: Option<&GpuSpec>, supports_gpu: bool) -> Result<(), Status> { + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(Status::invalid_argument("gpu.count must be greater than 0")); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(Status::invalid_argument( + "gpu.count is mutually exclusive with gpu.device_id", + )); + } + if gpu.count.is_some() { + return Err(Status::invalid_argument( + "docker compute driver does not support GPU count requests", + )); + } + } if gpu.is_some() && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. 
Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", @@ -1006,7 +1021,7 @@ fn build_container_create_body( .as_ref() .and_then(|placement| placement.gpu.as_ref()), ), - mounts: Some(build_mounts(config)), + binds: Some(build_binds(config)), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), maximum_retry_count: None, diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index 5522c2e04..ab4173f5f 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -488,7 +488,10 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), }); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -497,13 +500,34 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { assert!(err.message().contains("Docker CDI")); } +#[test] +fn validate_sandbox_rejects_gpu_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("does not support GPU count")); +} + #[test] fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); sandbox.spec.as_mut().unwrap().placement = Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: 
Some(GpuSpec { + device_id: vec![], + count: None, + }), }); let create_body = build_container_create_body(&sandbox, &config).unwrap(); @@ -532,6 +556,7 @@ fn build_container_create_body_passes_explicit_cdi_device_ids_through() { "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string(), ], + count: None, }), }); diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index b5b963f3a..9ea55815d 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -59,7 +59,7 @@ const SANDBOX_ID_LABEL: &str = "openshell.ai/sandbox-id"; const SANDBOX_MANAGED_LABEL: &str = "openshell.ai/managed-by"; const SANDBOX_MANAGED_VALUE: &str = "openshell"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; -const GPU_RESOURCE_QUANTITY: &str = "1"; +const DEFAULT_GPU_COUNT: u32 = 1; fn gpu_from_spec(spec: Option<&SandboxSpec>) -> Option<&GpuSpec> { spec.and_then(|spec| spec.placement.as_ref()) @@ -208,6 +208,18 @@ impl KubernetesComputeDriver { } async fn validate_gpu_request(&self, gpu: Option<&GpuSpec>) -> Result<(), tonic::Status> { + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(tonic::Status::invalid_argument( + "gpu.count must be greater than 0", + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(tonic::Status::invalid_argument( + "gpu.count is mutually exclusive with gpu.device_id", + )); + } + } if gpu_has_explicit_device_ids(gpu) { return Err(tonic::Status::invalid_argument( "kubernetes compute driver does not support explicit GPU device IDs", @@ -309,10 +321,23 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { - if gpu_has_explicit_device_ids(gpu_from_spec(sandbox.spec.as_ref())) { - return Err(KubernetesDriverError::Precondition( - "kubernetes compute driver does not support explicit GPU device IDs".to_string(), - )); + if let 
Some(gpu) = gpu_from_spec(sandbox.spec.as_ref()) { + if gpu.count == Some(0) { + return Err(KubernetesDriverError::Precondition( + "gpu.count must be greater than 0".to_string(), + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(KubernetesDriverError::Precondition( + "gpu.count is mutually exclusive with gpu.device_id".to_string(), + )); + } + if gpu_has_explicit_device_ids(Some(gpu)) { + return Err(KubernetesDriverError::Precondition( + "kubernetes compute driver does not support explicit GPU device IDs" + .to_string(), + )); + } } let name = sandbox.name.as_str(); @@ -1037,7 +1062,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( template, - gpu_from_spec(Some(spec)).is_some(), + gpu_from_spec(Some(spec)), &pod_env, inject_workspace, params, @@ -1073,7 +1098,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - gpu_from_spec(spec).is_some(), + gpu_from_spec(spec), &pod_env, inject_workspace, params, @@ -1088,7 +1113,7 @@ fn sandbox_to_k8s_spec( fn sandbox_template_to_k8s( template: &SandboxTemplate, - gpu: bool, + gpu: Option<&GpuSpec>, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, @@ -1121,7 +1146,7 @@ fn sandbox_template_to_k8s( if use_user_namespaces { spec.insert("hostUsers".to_string(), serde_json::json!(false)); - if gpu { + if gpu.is_some() { warn!( "GPU sandbox with user namespaces enabled — \ NVIDIA device plugin compatibility is unverified" @@ -1255,7 +1280,10 @@ fn sandbox_template_to_k8s( result } -fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu: Option<&GpuSpec>, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API // Struct), then overlay the typed DriverResourceRequirements on top. 
@@ -1277,8 +1305,9 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option) -> GpuSpec { + GpuSpec { + device_id: vec![], + count, + } + } + #[test] fn apply_required_env_always_injects_ssh_handshake_secret() { let mut env = Vec::new(); @@ -1817,7 +1853,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - true, + Some(&gpu_spec(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -1830,7 +1866,26 @@ mod tests { ); assert_eq!( pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) + ); + } + + #[test] + fn gpu_sandbox_uses_requested_gpu_count() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + Some(&gpu_spec(Some(2))), + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], + serde_json::json!("2") ); } @@ -1841,9 +1896,11 @@ mod tests { assert!(!gpu_has_explicit_device_ids(None)); assert!(!gpu_has_explicit_device_ids(Some(&GpuSpec { device_id: vec![], + count: None, }))); assert!(gpu_has_explicit_device_ids(Some(&GpuSpec { device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, }))); } @@ -1866,7 +1923,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_spec(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -1898,7 +1955,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -1926,7 +1983,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_spec(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -1937,7 +1994,7 @@ mod tests { 
assert_eq!(limits["cpu"], serde_json::json!("2")); assert_eq!( limits[GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) ); } @@ -1950,7 +2007,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -1975,7 +2032,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -1998,7 +2055,7 @@ mod tests { }; sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2137,7 +2194,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), false, // user provided custom VCTs ¶ms, @@ -2175,7 +2232,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2240,7 +2297,7 @@ mod tests { let params = SandboxPodParams::default(); // cluster default is off let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2278,7 +2335,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2304,7 +2361,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2399,7 +2456,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, @@ -2460,7 +2517,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, diff --git 
a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index 3c974e55c..dcad04a64 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -691,53 +691,6 @@ mod tests { assert_eq!(short_id("short"), "short"); } - #[test] - fn container_spec_omits_devices_without_gpu_request() { - let sandbox = test_sandbox("test-id", "test-name"); - let config = test_config(); - let spec = build_container_spec(&sandbox, &config); - - assert!(spec.get("devices").is_none()); - } - - #[test] - fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { - use openshell_core::config::CDI_GPU_DEVICE_ALL; - use openshell_core::proto::compute::v1::DriverSandboxSpec; - - let mut sandbox = test_sandbox("test-id", "test-name"); - sandbox.spec = Some(DriverSandboxSpec { - gpu: true, - ..Default::default() - }); - let config = test_config(); - let spec = build_container_spec(&sandbox, &config); - - assert_eq!( - spec["devices"][0]["path"].as_str(), - Some(CDI_GPU_DEVICE_ALL) - ); - } - - #[test] - fn container_spec_passes_explicit_cdi_device_id_through() { - use openshell_core::proto::compute::v1::DriverSandboxSpec; - - let mut sandbox = test_sandbox("test-id", "test-name"); - sandbox.spec = Some(DriverSandboxSpec { - gpu: true, - gpu_device: "nvidia.com/gpu=0".to_string(), - ..Default::default() - }); - let config = test_config(); - let spec = build_container_spec(&sandbox, &config); - - assert_eq!( - spec["devices"][0]["path"].as_str(), - Some("nvidia.com/gpu=0") - ); - } - #[test] fn container_spec_includes_required_capabilities() { let sandbox = test_sandbox("test-id", "test-name"); @@ -805,7 +758,10 @@ mod tests { let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { placement: Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), }), ..Default::default() }); 
@@ -832,6 +788,7 @@ mod tests { "nvidia.com/gpu=0".to_string(), "nvidia.com/gpu=1".to_string(), ], + count: None, }), }), ..Default::default() diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 07cff9cfa..7cd814529 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -207,6 +207,23 @@ impl PodmanComputeDriver { } fn validate_gpu_request(gpu: Option<&GpuSpec>) -> Result<(), ComputeDriverError> { + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(ComputeDriverError::Precondition( + "gpu.count must be greater than 0".to_string(), + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(ComputeDriverError::Precondition( + "gpu.count is mutually exclusive with gpu.device_id".to_string(), + )); + } + if gpu.count.is_some() { + return Err(ComputeDriverError::Precondition( + "podman compute driver does not support GPU count requests".to_string(), + )); + } + } if gpu.is_some() && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), @@ -227,6 +244,7 @@ impl PodmanComputeDriver { "sandbox id is required".into(), )); } + self.validate_sandbox_create(sandbox)?; // Validate the composed container name early, before creating any // resources (secret, volume), so we don't leave orphans when the @@ -579,6 +597,19 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_count() { + let err = PodmanComputeDriver::validate_gpu_request(Some(&GpuSpec { + device_id: vec![], + count: Some(2), + })) + .expect_err("GPU count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("does not support GPU count")) + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() 
fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index a635e2401..b5b579f81 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -1527,6 +1527,22 @@ fn validate_gpu_request(gpu: Option<&GpuSpec>, gpu_enabled: bool) -> Result<(), )); } + if let Some(gpu) = gpu { + if gpu.count == Some(0) { + return Err(Status::invalid_argument("gpu.count must be greater than 0")); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(Status::invalid_argument( + "gpu.count is mutually exclusive with gpu.device_id", + )); + } + if gpu.count.is_some_and(|count| count > 1) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU", + )); + } + } + if gpu.is_some_and(|gpu| gpu.device_id.len() > 1) { return Err(Status::invalid_argument( "vm compute driver supports at most one GPU device ID", @@ -2553,7 +2569,10 @@ mod tests { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { placement: Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), }), ..Default::default() }), @@ -2571,7 +2590,10 @@ mod tests { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { placement: Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), }), ..Default::default() }), @@ -2580,6 +2602,66 @@ mod tests { validate_vm_sandbox(&sandbox, true).expect("gpu should be accepted when enabled"); } + #[test] + fn validate_vm_sandbox_accepts_gpu_count_one_when_enabled() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(1), + }), + }), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, 
true).expect("gpu count one should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_greater_than_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("gpu count > 1 should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("at most one GPU")); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_with_device_id() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["0000:2d:00.0".to_string()], + count: Some(1), + }), + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("gpu count with device ID should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("mutually exclusive")); + } + #[test] fn validate_vm_sandbox_rejects_multiple_gpu_device_ids() { let sandbox = Sandbox { @@ -2588,6 +2670,7 @@ mod tests { placement: Some(PlacementRequirements { gpu: Some(GpuSpec { device_id: vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], + count: None, }), }), ..Default::default() @@ -2607,7 +2690,10 @@ mod tests { #[test] fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { - let gpu = GpuSpec { device_id: vec![] }; + let gpu = GpuSpec { + device_id: vec![], + count: None, + }; assert_eq!(requested_gpu_device(Some(&gpu)), Some("")); } @@ -2616,6 +2702,7 @@ mod tests { fn requested_gpu_device_returns_first_explicit_device_id() { let gpu = GpuSpec { device_id: vec!["0000:2d:00.0".to_string()], + count: None, }; assert_eq!(requested_gpu_device(Some(&gpu)), Some("0000:2d:00.0")); 
diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index ca49654c9..b9cc31aaa 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -1137,6 +1137,7 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .map(|placement| DriverPlacementRequirements { gpu: placement.gpu.as_ref().map(|gpu| DriverGpuSpec { device_id: gpu.device_id.clone(), + count: gpu.count, }), }), } @@ -1685,6 +1686,7 @@ mod tests { placement: Some(PlacementRequirements { gpu: Some(GpuSpec { device_id: vec!["nvidia.com/gpu=0".to_string()], + count: None, }), }), ..Default::default() @@ -1703,6 +1705,31 @@ mod tests { ); } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_count() { + let public = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .placement + .expect("driver placement requirements should be present") + .gpu + .expect("driver GPU request should be present") + .count, + Some(2) + ); + } + fn struct_value( fields: impl IntoIterator, prost_types::Value)>, ) -> prost_types::Value { @@ -2152,7 +2179,10 @@ mod tests { &mut status, Some(&SandboxSpec { placement: Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), }), ..Default::default() }), @@ -2413,7 +2443,10 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { placement: Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), }), ..Default::default() }), diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 7b34a188b..0f891a7d0 100644 --- 
a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -131,6 +131,13 @@ pub(super) fn validate_sandbox_spec( validate_sandbox_template(tmpl)?; } + // --- spec.placement --- + if let Some(placement) = spec.placement.as_ref() + && let Some(gpu) = placement.gpu.as_ref() + { + validate_gpu_spec(gpu)?; + } + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { let size = policy.encoded_len(); @@ -144,6 +151,20 @@ pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_gpu_spec(gpu: &openshell_core::proto::GpuSpec) -> Result<(), Status> { + if gpu.count == Some(0) { + return Err(Status::invalid_argument( + "placement.gpu.count must be greater than 0", + )); + } + if gpu.count.is_some() && !gpu.device_id.is_empty() { + return Err(Status::invalid_argument( + "placement.gpu.count is mutually exclusive with placement.gpu.device_id", + )); + } + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. 
@@ -669,13 +690,62 @@ mod tests { fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { placement: Some(PlacementRequirements { - gpu: Some(GpuSpec { device_id: vec![] }), + gpu: Some(GpuSpec { + device_id: vec![], + count: None, + }), }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(2), + }), + }), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); + } + + #[test] + fn validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec![], + count: Some(0), + }), + }), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("count must be greater than 0")); + } + + #[test] + fn validate_sandbox_spec_rejects_gpu_count_with_device_ids() { + let spec = SandboxSpec { + placement: Some(PlacementRequirements { + gpu: Some(GpuSpec { + device_id: vec!["nvidia.com/gpu=0".to_string()], + count: Some(1), + }), + }), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("mutually exclusive")); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 3eddecec5..52df28aa6 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -35,6 +35,19 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +To request a specific number of GPUs, use 
`--gpu-count`. GPU count requests are +mutually exclusive with explicit GPU device IDs, and the count must be greater +than zero. + +```shell +openshell sandbox create --gpu-count 2 -- claude +``` + +Kubernetes-backed sandboxes honor `--gpu-count` by setting the `nvidia.com/gpu` +resource limit to the requested count. VM-backed sandboxes accept only +`--gpu-count 1`. Docker-backed and Podman-backed sandboxes currently reject GPU +count requests. + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 5cee0cbb4..4524591d2 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -98,9 +98,11 @@ message PlacementRequirements { // Driver-native GPU placement details. message GPUSpec { - // Optional driver-native device identifiers. Empty means the driver chooses - // its default GPU assignment behavior. + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. repeated string device_id = 1; + // Optional number of GPUs requested. Mutually exclusive with device_id. + optional uint32 count = 2; } // Driver-owned runtime template consumed by the compute platform. diff --git a/proto/openshell.proto b/proto/openshell.proto index 093521723..9454febbb 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -265,9 +265,11 @@ message PlacementRequirements { // Public GPU placement details. Device identifiers are interpreted by the // selected compute driver. message GPUSpec { - // Optional driver-native device identifiers. Empty means the driver chooses - // its default GPU assignment behavior. + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. 
repeated string device_id = 1; + // Optional number of GPUs requested. Mutually exclusive with device_id. + optional uint32 count = 2; } // Public sandbox template mapped onto compute-driver template inputs.