diff --git a/cmd/driver/main.go b/cmd/driver/main.go index ef9e8e9..8fa37d2 100644 --- a/cmd/driver/main.go +++ b/cmd/driver/main.go @@ -40,6 +40,8 @@ func main() { "Secret name containing ca.crt for sandbox TLS verification (OPENSHELL_TLS_CA)") flag.StringVar(&cfg.TLSClientSecret, "tls-client-secret", cfg.TLSClientSecret, "Secret name containing tls.crt and tls.key for sandbox mTLS client auth") + flag.StringVar(&cfg.ImagePullPolicy, "sandbox-image-pull-policy", cfg.ImagePullPolicy, + "Image pull policy for sandbox pod containers (Always, IfNotPresent, Never); empty uses K8s default") flag.Parse() if cfg.Tenant == "" { diff --git a/internal/driver/config.go b/internal/driver/config.go index 669de4b..e4714e0 100644 --- a/internal/driver/config.go +++ b/internal/driver/config.go @@ -10,6 +10,7 @@ type Config struct { GatewayEndpoint string TLSCASecret string // Secret name containing ca.crt for gateway TLS verification TLSClientSecret string // Secret name containing tls.crt and tls.key for mTLS client auth + ImagePullPolicy string // Policy for sandbox pod containers (Always, IfNotPresent, Never); empty means K8s default } func DefaultConfig() Config { diff --git a/internal/driver/provisioner.go b/internal/driver/provisioner.go index 3e47116..30e8739 100644 --- a/internal/driver/provisioner.go +++ b/internal/driver/provisioner.go @@ -248,6 +248,9 @@ func (p *K8sProvisioner) buildSandboxSpec(sb *pb.DriverSandbox) map[string]inter }, }, } + if p.cfg.ImagePullPolicy != "" { + initContainer["imagePullPolicy"] = p.cfg.ImagePullPolicy + } // Agent container runs the supervisor and mounts it read-only. agentVolumeMounts := []interface{}{ @@ -286,6 +289,9 @@ func (p *K8sProvisioner) buildSandboxSpec(sb *pb.DriverSandbox) map[string]inter }, "volumeMounts": agentVolumeMounts, } + if p.cfg.ImagePullPolicy != "" { + container["imagePullPolicy"] = p.cfg.ImagePullPolicy + } if res := tmpl.GetResources(); res != nil { container["resources"] = buildResources(res, spec.GetGpu()) diff --git a/internal/driver/provisioner_test.go b/internal/driver/provisioner_test.go index fa1f8ba..d3691b2 100644 --- a/internal/driver/provisioner_test.go +++ b/internal/driver/provisioner_test.go @@ -379,3 +379,83 @@ func TestK8sProvisioner_Watch_ChannelCloses(t *testing.T) { for range ch { } } + +func TestBuildSandboxSpec_ImagePullPolicy(t *testing.T) { + cfg := testConfig() + cfg.ImagePullPolicy = "IfNotPresent" + + logger := testLogger() + scheme := runtime.NewScheme() + dynClient := dynamicfake.NewSimpleDynamicClientWithCustomListKinds( + scheme, + map[schema.GroupVersionResource]string{sandboxGVR: "SandboxList"}, + ) + clientset := kubefake.NewSimpleClientset() + p := NewK8sProvisioner(dynClient, clientset, cfg, logger) + + sb := &pb.DriverSandbox{ + Id: "sb-pull", + Spec: &pb.DriverSandboxSpec{ + Template: &pb.DriverSandboxTemplate{ + Image: "agent:latest", + }, + }, + } + + spec := p.buildSandboxSpec(sb) + podTemplate := spec["podTemplate"].(map[string]interface{}) + podSpec := podTemplate["spec"].(map[string]interface{}) + + // Verify init container has imagePullPolicy set. + initContainers := podSpec["initContainers"].([]interface{}) + initC := initContainers[0].(map[string]interface{}) + if initC["imagePullPolicy"] != "IfNotPresent" { + t.Errorf("expected init container imagePullPolicy=IfNotPresent, got %v", initC["imagePullPolicy"]) + } + + // Verify agent container has imagePullPolicy set. + containers := podSpec["containers"].([]interface{}) + agentC := containers[0].(map[string]interface{}) + if agentC["imagePullPolicy"] != "IfNotPresent" { + t.Errorf("expected agent container imagePullPolicy=IfNotPresent, got %v", agentC["imagePullPolicy"]) + } +} + +func TestBuildSandboxSpec_ImagePullPolicy_Empty(t *testing.T) { + cfg := testConfig() + // ImagePullPolicy left empty — should not appear in spec. + + logger := testLogger() + scheme := runtime.NewScheme() + dynClient := dynamicfake.NewSimpleDynamicClientWithCustomListKinds( + scheme, + map[schema.GroupVersionResource]string{sandboxGVR: "SandboxList"}, + ) + clientset := kubefake.NewSimpleClientset() + p := NewK8sProvisioner(dynClient, clientset, cfg, logger) + + sb := &pb.DriverSandbox{ + Id: "sb-nopull", + Spec: &pb.DriverSandboxSpec{ + Template: &pb.DriverSandboxTemplate{ + Image: "agent:latest", + }, + }, + } + + spec := p.buildSandboxSpec(sb) + podTemplate := spec["podTemplate"].(map[string]interface{}) + podSpec := podTemplate["spec"].(map[string]interface{}) + + initContainers := podSpec["initContainers"].([]interface{}) + initC := initContainers[0].(map[string]interface{}) + if _, ok := initC["imagePullPolicy"]; ok { + t.Error("expected no imagePullPolicy on init container when config is empty") + } + + containers := podSpec["containers"].([]interface{}) + agentC := containers[0].(map[string]interface{}) + if _, ok := agentC["imagePullPolicy"]; ok { + t.Error("expected no imagePullPolicy on agent container when config is empty") + } +}