diff --git a/integration-tests/tests/predict_json_path_list_input.txtar b/integration-tests/tests/predict_json_path_list_input.txtar new file mode 100644 index 0000000000..50918b0c7c --- /dev/null +++ b/integration-tests/tests/predict_json_path_list_input.txtar @@ -0,0 +1,28 @@ +# Test `--json` with list[Path] input + +cog build -t $TEST_IMAGE +cog predict $TEST_IMAGE --json '{"paths": ["@1.txt", "@2.txt"]}' +stdout '"status": "succeeded"' +stdout '"output": "test1test2"' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Path + + +class Predictor(BasePredictor): + def predict(self, paths: list[Path]) -> str: + output_parts = [] + for path in paths: + with open(path) as f: + output_parts.append(f.read()) + return "".join(output_parts) + +-- 1.txt -- +test1 +-- 2.txt -- +test2 diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index a8c322215c..364453d9e4 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -159,33 +159,59 @@ func transformPathsToBase64URLs(inputs map[string]any) (map[string]any, error) { result := make(map[string]any) for key, value := range inputs { - if strValue, ok := value.(string); ok && strings.HasPrefix(strValue, "@") { - // This is a file path, convert to base64 data URL - filePath := strValue[1:] + transformed, err := transformJSONValuePathsToBase64URLs(value) + if err != nil { + return nil, err + } + result[key] = transformed + } - // Read file - data, err := os.ReadFile(filePath) - if err != nil { - return nil, fmt.Errorf("Failed to read file %q: %w", filePath, err) - } + return result, nil +} - // Get MIME type - mimeType := mime.TypeByExtension(filepath.Ext(filePath)) - if mimeType == "" { - mimeType = "application/octet-stream" - } +func transformJSONValuePathsToBase64URLs(value any) (any, error) { + switch v := value.(type) { + case string: + if !strings.HasPrefix(v, "@") { + return v, nil + } - // Create base64 data URL - base64Data := base64.StdEncoding.EncodeToString(data) - dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data) + filePath := v[1:] + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("values starting with '@' are treated as file paths; failed to read %q: %w", filePath, err) + } - result[key] = dataURL - } else { - result[key] = value + mimeType := mime.TypeByExtension(filepath.Ext(filePath)) + if mimeType == "" { + mimeType = "application/octet-stream" } - } - return result, nil + base64Data := base64.StdEncoding.EncodeToString(data) + return fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data), nil + case []any: + out := make([]any, len(v)) + for i, item := range v { + transformed, err := transformJSONValuePathsToBase64URLs(item) + if err != nil { + return nil, err + } + out[i] = transformed + } + return out, nil + case map[string]any: + out := make(map[string]any, len(v)) + for key, item := range v { + transformed, err := transformJSONValuePathsToBase64URLs(item) + if err != nil { + return nil, err + } + out[key] = transformed + } + return out, nil + default: + return value, nil + } } func cmdPredict(cmd *cobra.Command, args []string) error { diff --git a/pkg/cli/predict_test.go b/pkg/cli/predict_test.go index f23c8a6f37..b4dd27311e 100644 --- a/pkg/cli/predict_test.go +++ b/pkg/cli/predict_test.go @@ -1,6 +1,8 @@ package cli import ( + "os" + "path/filepath" "testing" "github.com/getkin/kin-openapi/openapi3" @@ -169,3 +171,44 @@ func TestExtractOutputSchemaFromValidSchema(t *testing.T) { require.NotNil(t, outputSchema, "expected non-nil output schema for valid input") require.Contains(t, outputSchema.Type.Slice(), "string", "expected string type") } + +func TestTransformPathsToBase64URLsRecursesIntoNestedJSON(t *testing.T) { + dir := t.TempDir() + fileA := filepath.Join(dir, "a.txt") + fileB := filepath.Join(dir, "b.txt") + fileC := filepath.Join(dir, "c.txt") + require.NoError(t, os.WriteFile(fileA, []byte("alpha"), 0o644)) + require.NoError(t, os.WriteFile(fileB, []byte("beta"), 0o644)) + require.NoError(t, os.WriteFile(fileC, []byte("gamma"), 0o644)) + + inputs := map[string]any{ + "single": "@" + fileA, + "files": []any{ + "@" + fileB, + map[string]any{ + "inner": "@" + fileC, + }, + }, + "count": float64(3), + "plain": "hello", + } + + transformed, err := transformPathsToBase64URLs(inputs) + require.NoError(t, err) + + require.IsType(t, "", transformed["single"]) + require.Contains(t, transformed["single"].(string), "data:text/plain;base64,") + + files, ok := transformed["files"].([]any) + require.True(t, ok) + require.IsType(t, "", files[0]) + require.Contains(t, files[0].(string), "data:text/plain;base64,") + + innerObj, ok := files[1].(map[string]any) + require.True(t, ok) + require.IsType(t, "", innerObj["inner"]) + require.Contains(t, innerObj["inner"].(string), "data:text/plain;base64,") + + require.Equal(t, float64(3), transformed["count"]) + require.Equal(t, "hello", transformed["plain"]) +}