Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions integration-tests/tests/predict_json_path_list_input.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Test `--json` with list[Path] input

cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE --json '{"paths": ["@1.txt", "@2.txt"]}'
stdout '"status": "succeeded"'
stdout '"output": "test1test2"'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Path


class Predictor(BasePredictor):
    """Test predictor that concatenates the contents of its input files."""

    def predict(self, paths: list[Path]) -> str:
        """Read each file in ``paths`` in order and return the contents joined together."""

        def _slurp(file_path):
            # The context manager closes each file as soon as it has been read.
            with open(file_path) as handle:
                return handle.read()

        return "".join([_slurp(file_path) for file_path in paths])

-- 1.txt --
test1
-- 2.txt --
test2
68 changes: 47 additions & 21 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,33 +159,59 @@ func transformPathsToBase64URLs(inputs map[string]any) (map[string]any, error) {
result := make(map[string]any)

for key, value := range inputs {
if strValue, ok := value.(string); ok && strings.HasPrefix(strValue, "@") {
// This is a file path, convert to base64 data URL
filePath := strValue[1:]
transformed, err := transformJSONValuePathsToBase64URLs(value)
if err != nil {
return nil, err
}
result[key] = transformed
}

// Read file
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("Failed to read file %q: %w", filePath, err)
}
return result, nil
}

// Get MIME type
mimeType := mime.TypeByExtension(filepath.Ext(filePath))
if mimeType == "" {
mimeType = "application/octet-stream"
}
// transformJSONValuePathsToBase64URLs walks a decoded JSON value and replaces
// every string that starts with "@" with a base64 data URL built from the file
// named by the remainder of the string. Slices and maps are rebuilt with each
// element transformed recursively; any other value is returned unchanged.
// An error is returned if a referenced file cannot be read.
//
// NOTE(review): this span comes from a rendered diff, so lines of the replaced
// implementation and review-UI text appear interleaved with the new code.
func transformJSONValuePathsToBase64URLs(value any) (any, error) {
Comment thread
immanuwell marked this conversation as resolved.
switch v := value.(type) {
case string:
// Strings without the "@" marker are ordinary values and pass through as is.
if !strings.HasPrefix(v, "@") {
return v, nil
}

// Create base64 data URL
base64Data := base64.StdEncoding.EncodeToString(data)
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
// Everything after the leading "@" is the path of the file to embed.
filePath := v[1:]
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("values starting with '@' are treated as file paths; failed to read %q: %w", filePath, err)
}

result[key] = dataURL
} else {
result[key] = value
// Derive the MIME type from the file extension, falling back to a generic
// binary type when the extension is not recognized.
mimeType := mime.TypeByExtension(filepath.Ext(filePath))
if mimeType == "" {
mimeType = "application/octet-stream"
}
}

return result, nil
base64Data := base64.StdEncoding.EncodeToString(data)
return fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data), nil
case []any:
// Build a new slice of the same length with every element transformed.
out := make([]any, len(v))
for i, item := range v {
transformed, err := transformJSONValuePathsToBase64URLs(item)
if err != nil {
return nil, err
}
out[i] = transformed
}
return out, nil
case map[string]any:
// Build a new map with the same keys and every value transformed.
out := make(map[string]any, len(v))
for key, item := range v {
transformed, err := transformJSONValuePathsToBase64URLs(item)
if err != nil {
return nil, err
}
out[key] = transformed
}
return out, nil
default:
// Any other type is returned untouched.
return value, nil
}
}

func cmdPredict(cmd *cobra.Command, args []string) error {
Expand Down
43 changes: 43 additions & 0 deletions pkg/cli/predict_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cli

import (
"os"
"path/filepath"
"testing"

"github.com/getkin/kin-openapi/openapi3"
Expand Down Expand Up @@ -169,3 +171,44 @@ func TestExtractOutputSchemaFromValidSchema(t *testing.T) {
require.NotNil(t, outputSchema, "expected non-nil output schema for valid input")
require.Contains(t, outputSchema.Type.Slice(), "string", "expected string type")
}

// TestTransformPathsToBase64URLsRecursesIntoNestedJSON checks that
// transformPathsToBase64URLs replaces "@<path>" strings with base64 data URLs
// at the top level, inside arrays, and inside objects nested in arrays, while
// leaving numbers and plain strings untouched.
func TestTransformPathsToBase64URLsRecursesIntoNestedJSON(t *testing.T) {
	dir := t.TempDir()
	fileA := filepath.Join(dir, "a.txt")
	fileB := filepath.Join(dir, "b.txt")
	fileC := filepath.Join(dir, "c.txt")
	require.NoError(t, os.WriteFile(fileA, []byte("alpha"), 0o644))
	require.NoError(t, os.WriteFile(fileB, []byte("beta"), 0o644))
	require.NoError(t, os.WriteFile(fileC, []byte("gamma"), 0o644))

	inputs := map[string]any{
		"single": "@" + fileA,
		"files": []any{
			"@" + fileB,
			map[string]any{
				"inner": "@" + fileC,
			},
		},
		"count": float64(3),
		"plain": "hello",
	}

	transformed, err := transformPathsToBase64URLs(inputs)
	require.NoError(t, err)

	// requireDataURL asserts that got is a string of the form
	// "data:<mime>;base64,<wantPayload>". The payload is pinned so that
	// embedding the wrong file's contents is caught. The MIME part is left
	// open because mime.TypeByExtension consults the host's MIME tables, so
	// the type reported for ".txt" (and whether it carries a charset
	// parameter) can differ between machines.
	requireDataURL := func(got any, wantPayload string) {
		t.Helper()
		require.IsType(t, "", got)
		require.Regexp(t, `^data:[^,]+;base64,`+wantPayload+`$`, got)
	}

	requireDataURL(transformed["single"], "YWxwaGE=") // base64 of "alpha"

	files, ok := transformed["files"].([]any)
	require.True(t, ok)
	// Guard the index accesses below so a short slice fails the test cleanly
	// instead of panicking.
	require.Len(t, files, 2)
	requireDataURL(files[0], "YmV0YQ==") // base64 of "beta"

	innerObj, ok := files[1].(map[string]any)
	require.True(t, ok)
	requireDataURL(innerObj["inner"], "Z2FtbWE=") // base64 of "gamma"

	require.Equal(t, float64(3), transformed["count"])
	require.Equal(t, "hello", transformed["plain"])
}
Loading