-
Notifications
You must be signed in to change notification settings - Fork 681
feat(sandbox): add typed resource spec #1340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,7 +36,7 @@ use openshell_core::proto::{ | |
| LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, | ||
| ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, | ||
| ListServicesRequest, PolicySource, PolicyStatus, Provider, ProviderProfile, | ||
| ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, | ||
| ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, ResourceSpec, | ||
| RevokeSshSessionRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, | ||
| ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, SettingValue, | ||
| TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, | ||
|
|
@@ -1457,6 +1457,54 @@ async fn finalize_sandbox_create_session( | |
| session_result | ||
| } | ||
|
|
||
| #[derive(Clone, Copy, Debug, Default)] | ||
| pub struct SandboxResourceArgs<'a> { | ||
| pub cpu_request: Option<&'a str>, | ||
| pub cpu_limit: Option<&'a str>, | ||
| pub memory_request: Option<&'a str>, | ||
| pub memory_limit: Option<&'a str>, | ||
| pub gpu_count: Option<u32>, | ||
| pub driver_config: &'a [String], | ||
| } | ||
|
|
||
| fn build_resource_spec(args: SandboxResourceArgs<'_>) -> Result<Option<ResourceSpec>> { | ||
| let mut resources = ResourceSpec { | ||
| driver_config: parse_key_value_pairs(args.driver_config, "--resource-config")?, | ||
| ..ResourceSpec::default() | ||
| }; | ||
|
|
||
| if let Some(value) = args.cpu_request { | ||
| resources.cpu_request = value.to_string(); | ||
| } | ||
| if let Some(value) = args.cpu_limit { | ||
| resources.cpu_limit = value.to_string(); | ||
| } | ||
| if let Some(value) = args.memory_request { | ||
| resources.memory_request = value.to_string(); | ||
| } | ||
| if let Some(value) = args.memory_limit { | ||
| resources.memory_limit = value.to_string(); | ||
| } | ||
| if let Some(count) = args.gpu_count { | ||
| if count == 0 { | ||
| return Err(miette::miette!("--gpu-count must be greater than zero")); | ||
| } | ||
| resources.gpu_count = count; | ||
| } | ||
|
|
||
| if resources.cpu_request.is_empty() | ||
| && resources.cpu_limit.is_empty() | ||
| && resources.memory_request.is_empty() | ||
| && resources.memory_limit.is_empty() | ||
| && resources.gpu_count == 0 | ||
| && resources.driver_config.is_empty() | ||
| { | ||
| Ok(None) | ||
| } else { | ||
| Ok(Some(resources)) | ||
| } | ||
| } | ||
|
|
||
| /// Create a sandbox with default settings. | ||
| #[allow(clippy::too_many_arguments, clippy::implicit_hasher)] // user-facing CLI command; default hasher is fine | ||
| pub async fn sandbox_create( | ||
|
|
@@ -1467,6 +1515,7 @@ pub async fn sandbox_create( | |
| upload: Option<&(String, Option<String>, bool)>, | ||
| keep: bool, | ||
| gpu: bool, | ||
| resources: SandboxResourceArgs<'_>, | ||
| gpu_device: Option<&str>, | ||
|
Comment on lines
1517
to
1519
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My thinking was that we would add a
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Works for me. Thank you. Updates to PR incoming. |
||
| editor: Option<Editor>, | ||
| providers: &[String], | ||
|
|
@@ -1483,6 +1532,12 @@ pub async fn sandbox_create( | |
| "--editor cannot be used with a trailing command; use `openshell sandbox connect <name> --editor ...` after the sandbox is ready" | ||
| )); | ||
| } | ||
| let resources = build_resource_spec(resources)?; | ||
| if resources.as_ref().is_some_and(|spec| spec.gpu_count > 0) && gpu_device.is_some() { | ||
| return Err(miette::miette!( | ||
| "--gpu-count cannot be combined with --gpu-device" | ||
| )); | ||
| } | ||
|
|
||
| // Check port availability *before* creating the sandbox so we don't | ||
| // leave an orphaned sandbox behind when the forward would fail. | ||
|
|
@@ -1518,7 +1573,9 @@ pub async fn sandbox_create( | |
| } | ||
| None => None, | ||
| }; | ||
| let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); | ||
| let requested_gpu = gpu | ||
| || resources.as_ref().is_some_and(|spec| spec.gpu_count > 0) | ||
| || image.as_deref().is_some_and(image_requests_gpu); | ||
|
|
||
| let inferred_types: Vec<String> = inferred_provider_type(command).into_iter().collect(); | ||
| let configured_providers = ensure_required_providers( | ||
|
|
@@ -1531,17 +1588,22 @@ pub async fn sandbox_create( | |
|
|
||
| let policy = load_sandbox_policy(policy)?; | ||
|
|
||
| let template = image.map(|img| SandboxTemplate { | ||
| image: img, | ||
| ..SandboxTemplate::default() | ||
| }); | ||
| let template = if image.is_some() { | ||
| Some(SandboxTemplate { | ||
| image: image.unwrap_or_default(), | ||
| ..SandboxTemplate::default() | ||
| }) | ||
| } else { | ||
| None | ||
| }; | ||
|
|
||
| let request = CreateSandboxRequest { | ||
| spec: Some(SandboxSpec { | ||
| gpu: requested_gpu, | ||
| gpu_device: gpu_device.unwrap_or_default().to_string(), | ||
| policy, | ||
| providers: configured_providers, | ||
| resources, | ||
| template, | ||
| ..SandboxSpec::default() | ||
| }), | ||
|
|
@@ -6011,11 +6073,11 @@ fn format_timestamp_ms(ms: i64) -> String { | |
| #[cfg(test)] | ||
| mod tests { | ||
| use super::{ | ||
| TlsOptions, dockerfile_sources_supported_for_gateway, format_endpoint, | ||
| format_gateway_select_header, format_gateway_select_items, | ||
| format_provider_attachment_table, gateway_add, gateway_auth_label, | ||
| gateway_env_override_warning, gateway_select_with, gateway_type_label, git_sync_files, | ||
| http_health_check, image_requests_gpu, import_local_package_mtls_bundle, | ||
| SandboxResourceArgs, TlsOptions, build_resource_spec, | ||
| dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header, | ||
| format_gateway_select_items, format_provider_attachment_table, gateway_add, | ||
| gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, | ||
| git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, | ||
| inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, | ||
| parse_credential_pairs, plaintext_gateway_is_remote, provisioning_timeout_message, | ||
| ready_false_condition_message, resolve_from, sandbox_should_persist, | ||
|
|
@@ -6037,6 +6099,56 @@ mod tests { | |
| Provider, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta, | ||
| }; | ||
|
|
||
| #[test] | ||
| fn build_resource_spec_sets_typed_fields_and_driver_config() { | ||
| let driver_config = vec!["kubernetes.resource-name=nvidia.com/mig-1g.5gb".to_string()]; | ||
| let resources = build_resource_spec(SandboxResourceArgs { | ||
| cpu_request: Some("2"), | ||
| cpu_limit: Some("4"), | ||
| memory_request: Some("8Gi"), | ||
| memory_limit: Some("16Gi"), | ||
| gpu_count: Some(4), | ||
| driver_config: &driver_config, | ||
| }) | ||
| .expect("resources should parse") | ||
| .expect("resource spec should be present"); | ||
|
|
||
| assert_eq!(resources.cpu_request, "2"); | ||
| assert_eq!(resources.cpu_limit, "4"); | ||
| assert_eq!(resources.memory_request, "8Gi"); | ||
| assert_eq!(resources.memory_limit, "16Gi"); | ||
| assert_eq!(resources.gpu_count, 4); | ||
| assert_eq!( | ||
| resources.driver_config.get("kubernetes.resource-name"), | ||
| Some(&"nvidia.com/mig-1g.5gb".to_string()) | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn build_resource_spec_rejects_zero_gpu_count() { | ||
| let err = build_resource_spec(SandboxResourceArgs { | ||
| gpu_count: Some(0), | ||
| ..SandboxResourceArgs::default() | ||
| }) | ||
| .expect_err("zero GPU count should fail"); | ||
|
|
||
| assert!(err.to_string().contains("--gpu-count")); | ||
| assert!(err.to_string().contains("greater than zero")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn build_resource_spec_rejects_invalid_driver_config() { | ||
| let driver_config = vec!["missing-separator".to_string()]; | ||
| let err = build_resource_spec(SandboxResourceArgs { | ||
| driver_config: &driver_config, | ||
| ..SandboxResourceArgs::default() | ||
| }) | ||
| .expect_err("invalid driver config should fail"); | ||
|
|
||
| assert!(err.to_string().contains("--resource-config")); | ||
| assert!(err.to_string().contains("KEY=VALUE")); | ||
| } | ||
|
|
||
| struct EnvVarGuard { | ||
| key: &'static str, | ||
| original: Option<String>, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although this PR is said to resolve #1338 (gpu counts for k8s), doing so through a generic
resources-jsonflag doesn't do map well to other drivers. Would it make sense to separate the handling of--gpus-count(also added in #1156) from a more generic option such as this. If I understood correctly, this was added to ALSO set the CPU and MEMORY resource requests and pulling this into its own issue / PR with clearer expectations across the different drivers may make things easier to reason about.At the very least we should define the behaviour when a user specifies a
--resources-jsonbut the driver handling the sandbox creation doesn't support it.