Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion crates/openshell-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,7 @@ enum DoctorCommands {
Check,
}

#[allow(clippy::large_enum_variant)]
#[derive(Subcommand, Debug)]
enum SandboxCommands {
/// Create a sandbox.
Expand Down Expand Up @@ -1087,9 +1088,13 @@ enum SandboxCommands {
/// Target a driver-specific GPU device. Docker and Podman use CDI device IDs
/// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index.
/// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection.
#[arg(long, requires = "gpu")]
#[arg(long, requires = "gpu", conflicts_with = "gpu_count")]
gpu_device: Option<String>,

/// Request a specific number of GPUs. Mutually exclusive with --gpu-device.
#[arg(long, value_parser = clap::value_parser!(u32).range(1..), conflicts_with = "gpu_device")]
gpu_count: Option<u32>,

/// Provider names to attach to this sandbox.
#[arg(long = "provider")]
providers: Vec<String>,
Expand Down Expand Up @@ -2365,6 +2370,7 @@ async fn main() -> Result<()> {
editor,
gpu,
gpu_device,
gpu_count,
providers,
policy,
forward,
Expand Down Expand Up @@ -2431,6 +2437,7 @@ async fn main() -> Result<()> {
keep,
gpu,
gpu_device.as_deref(),
gpu_count,
editor,
&providers,
policy.as_deref(),
Expand Down Expand Up @@ -3641,6 +3648,52 @@ mod tests {
}
}

#[test]
fn sandbox_create_gpu_count_parses_without_gpu_flag() {
    // --gpu-count alone must parse: requesting a count implies a GPU request.
    let args = ["openshell", "sandbox", "create", "--gpu-count", "2"];
    let cli = Cli::try_parse_from(args).expect("sandbox create --gpu-count should parse");

    match cli.command {
        Some(Commands::Sandbox {
            command: Some(SandboxCommands::Create { gpu, gpu_count, .. }),
            ..
        }) => {
            assert!(!gpu);
            assert_eq!(gpu_count, Some(2));
        }
        _ => panic!("expected SandboxCommands::Create"),
    }
}

#[test]
fn sandbox_create_gpu_count_rejects_zero() {
    // The value parser restricts --gpu-count to 1.., so zero must fail to parse.
    let parsed = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "0"]);

    assert!(
        parsed.is_err(),
        "sandbox create --gpu-count 0 should be rejected"
    );
}

#[test]
fn sandbox_create_gpu_count_conflicts_with_gpu_device() {
    // The conflicts_with wiring must reject a count alongside a device id.
    let args = [
        "openshell",
        "sandbox",
        "create",
        "--gpu",
        "--gpu-device",
        "0",
        "--gpu-count",
        "2",
    ];

    assert!(
        Cli::try_parse_from(args).is_err(),
        "sandbox create should reject --gpu-count with --gpu-device"
    );
}

#[test]
fn service_expose_accepts_positional_target_port_and_service() {
let cli = Cli::try_parse_from([
Expand Down
81 changes: 72 additions & 9 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ use openshell_core::proto::{
GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest,
GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRequest,
GetSandboxConfigRequest, GetSandboxLogsRequest, GetSandboxPolicyStatusRequest,
GetSandboxRequest, GetServiceRequest, HealthRequest, ImportProviderProfilesRequest,
GetSandboxRequest, GetServiceRequest, GpuSpec, HealthRequest, ImportProviderProfilesRequest,
LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest,
ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest,
ListServicesRequest, PolicySource, PolicyStatus, Provider, ProviderProfile,
ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest,
ListServicesRequest, PlacementRequirements, PolicySource, PolicyStatus, Provider,
ProviderProfile, ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest,
RevokeSshSessionRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate,
ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, SettingValue,
TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest,
Expand Down Expand Up @@ -1468,6 +1468,7 @@ pub async fn sandbox_create(
keep: bool,
gpu: bool,
gpu_device: Option<&str>,
gpu_count: Option<u32>,
editor: Option<Editor>,
providers: &[String],
policy: Option<&str>,
Expand Down Expand Up @@ -1518,7 +1519,8 @@ pub async fn sandbox_create(
}
None => None,
};
let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu);
let requested_gpu =
gpu || gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu);

let inferred_types: Vec<String> = inferred_provider_type(command).into_iter().collect();
let configured_providers = ensure_required_providers(
Expand All @@ -1538,8 +1540,7 @@ pub async fn sandbox_create(

let request = CreateSandboxRequest {
spec: Some(SandboxSpec {
gpu: requested_gpu,
gpu_device: gpu_device.unwrap_or_default().to_string(),
placement: placement_requirements_from_cli(requested_gpu, gpu_device, gpu_count),
policy,
providers: configured_providers,
template,
Expand Down Expand Up @@ -1971,6 +1972,27 @@ pub async fn sandbox_create(
}
}

/// Build the `PlacementRequirements` carried on a sandbox spec from the CLI
/// GPU flags.
///
/// Returns `None` when no GPU was requested at all (neither the presence flag
/// nor a count). A requested count wins over a device selection — the CLI
/// already forbids passing both — so the device list is only populated when no
/// count was given, and an empty list leaves device selection to the driver.
fn placement_requirements_from_cli(
    requested_gpu: bool,
    gpu_device: Option<&str>,
    gpu_count: Option<u32>,
) -> Option<PlacementRequirements> {
    if !requested_gpu && gpu_count.is_none() {
        return None;
    }

    // Only honour a non-empty device id when no count was requested.
    let device_id = match (gpu_count, gpu_device) {
        (None, Some(id)) if !id.is_empty() => vec![id.to_string()],
        _ => Vec::new(),
    };

    Some(PlacementRequirements {
        gpu: Some(GpuSpec {
            device_id,
            count: gpu_count,
        }),
    })
}

/// Resolved source for the `--from` flag on `sandbox create`.
#[derive(Debug)]
enum ResolvedSource {
Expand Down Expand Up @@ -6017,9 +6039,9 @@ mod tests {
gateway_env_override_warning, gateway_select_with, gateway_type_label, git_sync_files,
http_health_check, image_requests_gpu, import_local_package_mtls_bundle,
inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value,
parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message,
ready_false_condition_message, resolve_from, sandbox_should_persist,
service_expose_status_error, service_url_for_gateway,
parse_credential_pairs, placement_requirements_from_cli, plaintext_gateway_is_remote,
provisioning_timeout_message, ready_false_condition_message, resolve_from,
sandbox_should_persist, service_expose_status_error, service_url_for_gateway,
};
use crate::TEST_ENV_LOCK;
use hyper::StatusCode;
Expand Down Expand Up @@ -6296,6 +6318,47 @@ mod tests {
}
}

#[test]
fn image_requests_gpu_detects_known_community_gpu_name() {
    // GPU-flavoured community image names must be detected; plain ones must not.
    for (image, expected) in [("nvidia-gpu", true), ("base", false)] {
        assert_eq!(image_requests_gpu(image), expected);
    }
}

#[test]
fn placement_requirements_from_cli_uses_presence_with_empty_device_ids_for_default_gpu() {
    // --gpu alone: placement present, no explicit devices, no count.
    let gpu = placement_requirements_from_cli(true, None, None)
        .expect("placement requirements should be present")
        .gpu
        .expect("gpu request should be present");

    assert!(gpu.device_id.is_empty());
    assert_eq!(gpu.count, None);
}

#[test]
fn placement_requirements_from_cli_maps_gpu_device_to_one_device_id() {
    // A driver-native device id becomes a single-element device list.
    let gpu = placement_requirements_from_cli(true, Some("0000:2d:00.0"), None)
        .expect("placement requirements should be present")
        .gpu
        .expect("gpu request should be present");

    assert_eq!(gpu.device_id, vec!["0000:2d:00.0"]);
    assert_eq!(gpu.count, None);
}

#[test]
fn placement_requirements_from_cli_maps_gpu_count() {
    // A count by itself implies a GPU request with no explicit devices.
    let gpu = placement_requirements_from_cli(false, None, Some(2))
        .expect("placement requirements should be present")
        .gpu
        .expect("gpu request should be present");

    assert!(gpu.device_id.is_empty());
    assert_eq!(gpu.count, Some(2));
}

#[test]
fn placement_requirements_from_cli_omits_placement_when_not_requested() {
    // A stray device id without --gpu or --gpu-count must not create a request.
    let placement = placement_requirements_from_cli(false, Some("0"), None);
    assert!(placement.is_none());
}

#[test]
fn resolve_from_classifies_existing_dockerfile_path() {
let temp = tempfile::tempdir().expect("failed to create tempdir");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() {
false,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -710,6 +711,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() {
false,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -752,6 +754,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() {
false,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -794,6 +797,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() {
false,
None,
None,
None,
&[],
None,
None,
Expand Down Expand Up @@ -836,6 +840,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() {
false,
None,
None,
None,
&[],
None,
Some(openshell_core::forward::ForwardSpec::new(forward_port)),
Expand Down
59 changes: 43 additions & 16 deletions crates/openshell-core/src/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@
//! Shared GPU request helpers.

use crate::config::CDI_GPU_DEVICE_ALL;
use crate::proto::compute::v1::GpuSpec;

/// Resolve the existing GPU request fields into CDI device identifiers.
/// Resolve a driver GPU request into CDI device identifiers.
///
/// `None` means no GPU was requested. A GPU request with no explicit device
/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes
/// through unchanged.
/// `None` means no GPU was requested. Presence with no explicit device IDs
/// uses the CDI all-GPU request; otherwise the driver-native IDs pass through.
#[must_use]
pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option<Vec<String>> {
gpu.then(|| {
if gpu_device.is_empty() {
vec![CDI_GPU_DEVICE_ALL.to_string()]
} else {
vec![gpu_device.to_string()]
}
})
pub fn cdi_gpu_device_ids(gpu: Option<&GpuSpec>) -> Option<Vec<String>> {
    gpu.map(|spec| {
        // An empty device list means "any GPU": expand it to the CDI all-GPU id.
        if spec.device_id.is_empty() {
            vec![CDI_GPU_DEVICE_ALL.to_string()]
        } else {
            spec.device_id.clone()
        }
    })
}

#[cfg(test)]
Expand All @@ -27,22 +25,51 @@ mod tests {

#[test]
fn cdi_gpu_device_ids_returns_none_when_absent() {
    // No GpuSpec at all: no CDI devices should be resolved.
    assert_eq!(cdi_gpu_device_ids(None), None);
}

#[test]
fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() {
    let request = GpuSpec {
        device_id: vec![],
        count: None,
    };

    // An empty device list expands to the CDI all-GPU identifier.
    let resolved = cdi_gpu_device_ids(Some(&request));
    assert_eq!(resolved, Some(vec![CDI_GPU_DEVICE_ALL.to_string()]));
}

#[test]
fn cdi_gpu_device_ids_passes_single_device_id_through() {
    let request = GpuSpec {
        device_id: vec!["nvidia.com/gpu=0".to_string()],
        count: None,
    };

    // Explicit driver-native ids pass through untouched.
    let resolved = cdi_gpu_device_ids(Some(&request));
    assert_eq!(resolved, Some(vec!["nvidia.com/gpu=0".to_string()]));
}

#[test]
fn cdi_gpu_device_ids_passes_multiple_device_ids_through() {
    let request = GpuSpec {
        device_id: vec![
            "nvidia.com/gpu=0".to_string(),
            "nvidia.com/gpu=1".to_string(),
        ],
        count: None,
    };

    // Multi-device requests are forwarded verbatim, order preserved.
    let resolved = cdi_gpu_device_ids(Some(&request));
    assert_eq!(
        resolved,
        Some(vec![
            "nvidia.com/gpu=0".to_string(),
            "nvidia.com/gpu=1".to_string()
        ])
    );
}
}
Loading
Loading