diff --git a/architecture/02-schema.md b/architecture/02-schema.md index f038219360..6ed6d16742 100644 --- a/architecture/02-schema.md +++ b/architecture/02-schema.md @@ -102,15 +102,15 @@ class Runner(BaseRunner): The resolver handles local imports relative to the predictor file and project root: -| Import Style | File Resolved | -| ------------------------------------ | ----------------------------------------------------------- | -| `from output_types import X` | `/output_types.py` | -| `from .output_types import X` | `/output_types.py` | -| `from models.output import X` | `/models/output.py` | -| `from .models.output import X` | `/models/output.py` | -| `from output_types import X as Y` | `/output_types.py` (alias tracked) | -| `from .output_types import X as Y` | `/output_types.py` (alias tracked) | -| `from . import output_types` | `/output_types.py` (module alias tracked) | +| Import Style | File Resolved | +| ---------------------------------- | -------------------------------------------------------- | +| `from output_types import X` | `/output_types.py` | +| `from .output_types import X` | `/output_types.py` | +| `from models.output import X` | `/models/output.py` | +| `from .models.output import X` | `/models/output.py` | +| `from output_types import X as Y` | `/output_types.py` (alias tracked) | +| `from .output_types import X as Y` | `/output_types.py` (alias tracked) | +| `from . import output_types` | `/output_types.py` (module alias tracked) | **How it distinguishes local from external**: the resolver converts the module path to a filesystem path and checks if the file exists. If `output_types.py` exists in the project directory, it's local. If not (e.g., `from transformers import ...`), it's external. Known external packages (stdlib, torch, numpy, etc.) are skipped without a filesystem check. @@ -175,18 +175,28 @@ Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: ### Input Types -| Python | JSON Schema | Notes | -| ------------------------------------- | ---------------------------------------------------------------- | -------------------------- | -| `str` | `{"type": "string"}` | | -| `int` | `{"type": "integer"}` | | -| `float` | `{"type": "number"}` | | -| `bool` | `{"type": "boolean"}` | | -| `cog.Path` | `{"type": "string", "format": "uri"}` | URLs downloaded at runtime | -| `cog.File` | `{"type": "string", "format": "uri"}` | File uploads | -| `cog.Secret` | `{"type": "string", "format": "password", "x-cog-secret": true}` | Masked in logs | -| `list[T]` | `{"type": "array", "items": {...}}` | | -| `Optional[T]` | Type T + not in `required` | Input fields only | -| `Literal["a", "b"]` / `choices=[...]` | `{"enum": ["a", "b"]}` | | +| Python | JSON Schema | Notes | +| ------------------------------------- | ---------------------------------------------------------------- | --------------------------------------------------------------------- | +| `str` | `{"type": "string"}` | | +| `int` | `{"type": "integer"}` | | +| `float` | `{"type": "number"}` | | +| `bool` | `{"type": "boolean"}` | | +| `cog.Path` | `{"type": "string", "format": "uri"}` | URLs downloaded at runtime | +| `cog.File` | `{"type": "string", "format": "uri"}` | File uploads | +| `cog.Secret` | `{"type": "string", "format": "password", "x-cog-secret": true}` | Masked in logs | +| `list[T]` | `{"type": "array", "items": {...}}` | | +| `Optional[T]` / `T \| None` | Type T + `nullable: true`, not in `required` | Input fields only; never required | +| `A \| B` / `Union[A, B]` | `{"anyOf": [A, B]}` | Input-only, JSON-native unions only | +| `A \| B \| None` | `{"anyOf": [A, B]}` + `nullable: true` | Multi-variant union; stays in `required` unless a default is supplied | +| `Literal["a", "b"]` / `choices=[...]` | `{"enum": ["a", "b"]}` | | + +Input unions are intentionally narrower than output types. Cog supports JSON-native input unions (`str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`) so request validation can happen at the HTTP boundary and Python normalisation can choose a deterministic value type. Cog rejects unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` because those cases are ambiguous for clients or runtime coercion. Output unions remain unsupported (see below). + +A plain single-type optional (`Optional[T]` or `T | None`) is **never** placed in `required`, regardless of whether a default is supplied. A multi-variant nullable union (`A | B | None`) is different: because the field carries a concrete `anyOf` value type, it stays in `required` unless a default makes it omittable. This is why the two rows above differ in their `required` behaviour. + +Nullable behaviour matches every other optional field: `nullable: true` (plus omission from `required`) means an **omitted** value falls back to the default. An **explicit** JSON `null` is still validated against the field type and is rejected at the HTTP edge, because the runtime validator does not treat OpenAPI's `nullable` keyword as an additional accepted value. "May be null" therefore means "may be omitted", not "accepts an explicit null payload". + +> **Runtime caveat:** Cog marks optionals as not-`required` in the schema, but the predictor still needs a Python-level default so the omitted value resolves to `None`. Use `value: Optional[T] = Input(...)` (the `Input(...)` supplies an implicit `None`) or `Input(default=None)`. A bare `value: Optional[T]` annotation with no `= Input(...)` generates a correct "optional" schema but raises `TypeError: missing 1 required positional argument` when the field is omitted at runtime. ### Output Types diff --git a/docs/llms.txt b/docs/llms.txt index e8e9e1f7c7..92c90a36ab 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -2652,6 +2652,30 @@ def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: > [!NOTE] > `Optional[T]` is supported in `BaseModel` output fields but **not** as a top-level return type. Use a `BaseModel` with optional fields instead. +#### `Union` + +Use `A | B` or `Union[A, B]` to accept more than one type for a single input. Cog supports JSON-native union members: `str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`. + +```python +from cog import BaseRunner, Input + +class Runner(BaseRunner): + def run(self, + value: str | float = Input(description="A string or a number"), + ) -> str: + return f"{type(value).__name__}:{value}" +``` + +At runtime, Cog validates the request against the union and passes the value through as the matching type. For overlapping numeric types, Cog prefers the most specific match (e.g. `bool` before `int`, `int` before `float`), and a JSON integer is accepted for a `float` member. + +Combine a union with `None` to make it nullable: + +```python +def run(self, value: str | float | None = Input(default=None)) -> str: ... +``` + +Union inputs are validated at the HTTP boundary, so unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` are **not** supported, and the build fails if you use them. Union return types are also unsupported — use a `BaseModel` output instead. + #### `list` Use `list[T]` or `List[T]` to accept or return a list of values. `T` can be a supported Cog type, but nested container types are not supported. diff --git a/docs/python.md b/docs/python.md index 5f5619a9af..2603101ff9 100644 --- a/docs/python.md +++ b/docs/python.md @@ -620,6 +620,30 @@ def run(self, prompt: Optional[str] = Input(description="prompt")) -> str: > [!NOTE] > `Optional[T]` is supported in `BaseModel` output fields but **not** as a top-level return type. Use a `BaseModel` with optional fields instead. +#### `Union` + +Use `A | B` or `Union[A, B]` to accept more than one type for a single input. Cog supports JSON-native union members: `str`, `int`, `float`, `bool`, `dict`/`Any`, `list[T]`, and `None`. + +```python +from cog import BaseRunner, Input + +class Runner(BaseRunner): + def run(self, + value: str | float = Input(description="A string or a number"), + ) -> str: + return f"{type(value).__name__}:{value}" +``` + +At runtime, Cog validates the request against the union and passes the value through as the matching type. For overlapping numeric types, Cog prefers the most specific match (e.g. `bool` before `int`, `int` before `float`), and a JSON integer is accepted for a `float` member. + +Combine a union with `None` to make it nullable: + +```python +def run(self, value: str | float | None = Input(default=None)) -> str: ... +``` + +Union inputs are validated at the HTTP boundary, so unions involving `Path`, `File`, `Secret`, custom coders, and `BaseModel` are **not** supported, and the build fails if you use them. Union return types are also unsupported — use a `BaseModel` output instead. + #### `list` Use `list[T]` or `List[T]` to accept or return a list of values. `T` can be a supported Cog type, but nested container types are not supported. diff --git a/integration-tests/tests/union_input_cli.txtar b/integration-tests/tests/union_input_cli.txtar new file mode 100644 index 0000000000..aeebd7fa4d --- /dev/null +++ b/integration-tests/tests/union_input_cli.txtar @@ -0,0 +1,45 @@ +# Test schema-directed CLI parsing for JSON-native union inputs. +# +# value: str | float (string member first) +# flipped: float | str (number member first) +# +# - "hello" parses as a string because no numeric parse succeeds +# - "1.5" parses as a float because the union accepts a number +# - "1" parses as an integer (still a valid JSON number for the union) +# +# The `flipped` field exercises the case where resolveSchemaType resolves a +# union to its numeric member first: a non-numeric value must still fall back +# to the string member instead of erroring. +# +# Note: the worker does not coerce primitives at runtime (validation happens +# at the HTTP edge against the OpenAPI schema), so the CLI must choose the +# wire type. A bare integer stays an integer; only fractional values become +# floats. This matches how a plain `float` input also receives a Python int +# for `-i num=10`. + +cog build -t $TEST_IMAGE + +cog predict $TEST_IMAGE -i value=hello -i flipped=world +stdout 'value=str:hello flipped=str:world' + +cog predict $TEST_IMAGE -i value=1.5 -i flipped=2.5 +stdout 'value=float:1.5 flipped=float:2.5' + +cog predict $TEST_IMAGE -i value=1 -i flipped=2 +stdout 'value=int:1 flipped=int:2' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, value: str | float, flipped: float | str) -> str: + return ( + f"value={type(value).__name__}:{value} " + f"flipped={type(flipped).__name__}:{flipped}" + ) diff --git a/integration-tests/tests/union_input_http.txtar b/integration-tests/tests/union_input_http.txtar new file mode 100644 index 0000000000..7ec032c426 --- /dev/null +++ b/integration-tests/tests/union_input_http.txtar @@ -0,0 +1,54 @@ +# Test JSON-native union inputs over cog serve. +# value: str | float | None = Input(default=None) +# - string accepted, returns "str:..." +# - float accepted, returns "float:..." +# - integer accepted (valid JSON number), returns "int:..." +# - omitted optional defaults to None, returns "NoneType:none" +# - explicit null is rejected, matching how every optional field behaves +# today (validation happens at the HTTP edge and the runtime validator +# does not accept explicit JSON null for a typed field) +# - bool rejected (not a member of str | float), returns a validation error + +cog build -t $TEST_IMAGE + +cog serve + +# String member +curl POST /predictions '{"input":{"value":"hello"}}' +stdout '"output":"str:hello"' + +# Float member +curl POST /predictions '{"input":{"value":1.5}}' +stdout '"output":"float:1.5"' + +# Integer is a valid JSON number for the union; passed through as int +curl POST /predictions '{"input":{"value":1}}' +stdout '"output":"int:1"' + +# Omitted optional value defaults to None +curl POST /predictions '{"input":{}}' +stdout '"output":"NoneType:none"' + +# Explicit null is rejected, consistent with all optional fields +! curl POST /predictions '{"input":{"value":null}}' + +# bool is not a member of str | float -> rejected +! curl POST /predictions '{"input":{"value":true}}' + +# nested object is not a member of str | float -> rejected +! curl POST /predictions '{"input":{"value":{"x":1}}}' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict(self, value: str | float | None = Input(default=None)) -> str: + if value is None: + return "NoneType:none" + return f"{type(value).__name__}:{value}" diff --git a/integration-tests/tests/union_input_list_http.txtar b/integration-tests/tests/union_input_list_http.txtar new file mode 100644 index 0000000000..e36fa90be5 --- /dev/null +++ b/integration-tests/tests/union_input_list_http.txtar @@ -0,0 +1,48 @@ +# Test list JSON-native union inputs over cog serve. +# +# nums: list[int] | list[float] (required list union) +# +# - list[int] | list[float] accepts [1] and [1.5] and the empty list [] +# - list element types are validated: ["3"] and [true] are rejected +# - integer elements stay int, fractional elements are float (no runtime +# coercion: the wire type is preserved) + +cog build -t $TEST_IMAGE + +cog serve + +# Integer list element kept as int +curl POST /predictions '{"input":{"nums":[1]}}' +stdout '"output":"int:1"' + +# Float list element +curl POST /predictions '{"input":{"nums":[1.5]}}' +stdout '"output":"float:1.5"' + +# Empty list is accepted +curl POST /predictions '{"input":{"nums":[]}}' +stdout '"output":"empty"' + +# String element is not valid for list[int] | list[float] -> rejected +! curl POST /predictions '{"input":{"nums":["3"]}}' + +# bool element is not valid for list[int] | list[float] -> rejected +! curl POST /predictions '{"input":{"nums":[true]}}' + +# A bare scalar is not a list -> rejected +! curl POST /predictions '{"input":{"nums":1}}' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, nums: list[int] | list[float]) -> str: + if not nums: + return "empty" + return f"{type(nums[0]).__name__}:{nums[0]}" diff --git a/pkg/predict/input.go b/pkg/predict/input.go index 0e17993f31..e0af7d3253 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -66,11 +66,11 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b propertiesSchemas := properties.(openapi3.Schemas) property, err := propertiesSchemas.JSONLookup(key) if err == nil { - propertySchema := property.(*openapi3.Schema) + originalSchema := property.(*openapi3.Schema) // Resolve allOf/$ref to find the actual type. // cog-schema-gen emits allOf:[{$ref: ...}] for choices/enums, // where the referenced schema has the concrete type. - propertySchema = resolveSchemaType(propertySchema) + propertySchema := resolveSchemaType(originalSchema) switch { case propertySchema.Type.Is("object"): encodedVal := json.RawMessage(val) @@ -99,6 +99,13 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b } else { value, err := strconv.ParseFloat(val, 32) if err != nil { + // For a union like `float | str` the schema + // resolves to the numeric member first; a + // non-numeric value should fall back to the + // string member instead of erroring. + if schemaAcceptsString(originalSchema) { + break + } return input, err } float := float32(value) @@ -108,11 +115,48 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b case propertySchema.Type.Is("integer"): value, err := strconv.ParseInt(val, 10, 32) if err != nil { + // For a union like `int | float` the schema + // resolves to the integer member first; a + // fractional value should fall back to the float + // member instead of erroring. + if schemaAcceptsFloat(originalSchema) { + if value, err := strconv.ParseFloat(val, 32); err == nil { + float := float32(value) + input[key] = Input{Float: &float} + continue + } + } + // See the number case above: fall back to a string + // member for unions such as `int | str`. + if schemaAcceptsString(originalSchema) { + break + } return input, err } valueInt := int32(value) input[key] = Input{Int: &valueInt} continue + case schemaAcceptsNumber(originalSchema): + // Union input (anyOf) that includes a numeric member, e.g. + // `str | float`. Parse numeric-looking values as numbers so + // the runtime receives the intended type; otherwise fall + // through to the string member below. + if value, err := strconv.ParseInt(val, 10, 32); err == nil { + valueInt := int32(value) + input[key] = Input{Int: &valueInt} + continue + } + // Only parse fractional values as float when the + // union actually accepts a float member; otherwise a + // value like `1.5` for `str | int` must fall back to + // the string member below. + if schemaAcceptsFloat(originalSchema) { + if value, err := strconv.ParseFloat(val, 32); err == nil { + float := float32(value) + input[key] = Input{Float: &float} + continue + } + } } } } @@ -188,6 +232,79 @@ func fileToDataURL(filePath string) (string, error) { return dataURL, nil } +// schemaAcceptsString reports whether the schema accepts a string value, +// including union (anyOf) members. This lets CLI `-i` parsing fall back to a +// string member when a numeric parse fails for unions such as `float | str`, +// where resolveSchemaType resolves to the numeric member. +func schemaAcceptsString(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && s.Type.Is("string") { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsString(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsString(ref.Value) { + return true + } + } + return false +} + +// schemaAcceptsFloat reports whether the schema accepts a floating-point +// value, including union (anyOf) members. Unlike schemaAcceptsNumber, it does +// not match integer-only members, so CLI `-i` parsing can decide whether a +// fractional value like `1.5` is valid for unions such as `int | float` +// (accepts float) versus `str | int` (does not). +func schemaAcceptsFloat(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && s.Type.Is("number") { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsFloat(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsFloat(ref.Value) { + return true + } + } + return false +} + +// schemaAcceptsNumber reports whether the schema accepts a numeric value, +// including union (anyOf) members. This lets CLI `-i` parsing coerce +// numeric-looking strings for union inputs such as `str | float`, where +// resolveSchemaType resolves to a non-numeric member. +func schemaAcceptsNumber(s *openapi3.Schema) bool { + if s == nil { + return false + } + if s.Type != nil && (s.Type.Is("number") || s.Type.Is("integer")) { + return true + } + for _, ref := range s.AnyOf { + if ref.Value != nil && schemaAcceptsNumber(ref.Value) { + return true + } + } + for _, ref := range s.AllOf { + if ref.Value != nil && schemaAcceptsNumber(ref.Value) { + return true + } + } + return false +} + // resolveSchemaType walks through allOf/anyOf/$ref wrappers to find a schema // that has a concrete Type set. This is needed because the static schema gen // emits allOf:[{$ref: "#/components/schemas/Foo"}] for enum/choices fields, diff --git a/pkg/predict/input_test.go b/pkg/predict/input_test.go new file mode 100644 index 0000000000..b383d3ff93 --- /dev/null +++ b/pkg/predict/input_test.go @@ -0,0 +1,226 @@ +package predict + +import ( + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" +) + +// unionInputSchema builds an OpenAPI doc whose single input field `value` +// is a union of string and number. The variant order is configurable so we +// can exercise both `str | float` (string first) and `float | str` (number +// first), which resolve differently via resolveSchemaType. +func unionInputSchema(numberFirst bool) *openapi3.T { + stringRef := openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + numberRef := openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}} + anyOf := openapi3.SchemaRefs{&stringRef, &numberRef} + if numberFirst { + anyOf = openapi3.SchemaRefs{&numberRef, &stringRef} + } + valueSchema := &openapi3.Schema{AnyOf: anyOf} + inputSchema := &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "value": {Value: valueSchema}, + }, + } + return &openapi3.T{ + Components: &openapi3.Components{ + Schemas: openapi3.Schemas{ + "Input": {Value: inputSchema}, + }, + }, + } +} + +func TestNewInputsForMode_UnionParsesNumber(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + numberFirst bool + val string + wantInt *int32 + wantFlt *float32 + wantStr *string + }{ + // str | float (string member first) + {name: "str|float integer", val: "1", wantInt: ptrI32(1)}, + {name: "str|float float", val: "1.5", wantFlt: ptrF32(1.5)}, + {name: "str|float string", val: "hello", wantStr: ptrStr("hello")}, + // float | str (number member first) -- must still fall back to string + {name: "float|str integer", numberFirst: true, val: "1", wantInt: ptrI32(1)}, + {name: "float|str float", numberFirst: true, val: "1.5", wantFlt: ptrF32(1.5)}, + {name: "float|str string", numberFirst: true, val: "hello", wantStr: ptrStr("hello")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + schema := unionInputSchema(tt.numberFirst) + inputs, err := NewInputsForMode(map[string][]string{"value": {tt.val}}, schema, false) + require.NoError(t, err) + + got := inputs["value"] + switch { + case tt.wantInt != nil: + require.NotNil(t, got.Int) + require.Equal(t, *tt.wantInt, *got.Int) + case tt.wantFlt != nil: + require.NotNil(t, got.Float) + require.Equal(t, *tt.wantFlt, *got.Float) + case tt.wantStr != nil: + require.NotNil(t, got.String) + require.Equal(t, *tt.wantStr, *got.String) + } + }) + } +} + +// unionInputSchemaOf builds an OpenAPI doc whose single input field `value` +// is a union (anyOf) of the given JSON Schema types, in the given order. +func unionInputSchemaOf(types ...string) *openapi3.T { + anyOf := make(openapi3.SchemaRefs, len(types)) + for i, t := range types { + anyOf[i] = &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{t}}} + } + inputSchema := &openapi3.Schema{ + Type: &openapi3.Types{"object"}, + Properties: openapi3.Schemas{ + "value": {Value: &openapi3.Schema{AnyOf: anyOf}}, + }, + } + return &openapi3.T{ + Components: &openapi3.Components{ + Schemas: openapi3.Schemas{ + "Input": {Value: inputSchema}, + }, + }, + } +} + +func TestNewInputsForMode_UnionIntFloatAndStrInt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + types []string + val string + wantInt *int32 + wantFlt *float32 + wantStr *string + }{ + // int | float: integer member resolves first; a fractional value must + // fall back to the float member instead of erroring. + {name: "int|float integer", types: []string{"integer", "number"}, val: "1", wantInt: ptrI32(1)}, + {name: "int|float fractional", types: []string{"integer", "number"}, val: "1.5", wantFlt: ptrF32(1.5)}, + // str | int: string resolves first; a fractional value is not valid for + // the integer member and must fall back to the string member. + {name: "str|int integer", types: []string{"string", "integer"}, val: "1", wantInt: ptrI32(1)}, + {name: "str|int fractional", types: []string{"string", "integer"}, val: "1.5", wantStr: ptrStr("1.5")}, + {name: "str|int string", types: []string{"string", "integer"}, val: "hello", wantStr: ptrStr("hello")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + schema := unionInputSchemaOf(tt.types...) + inputs, err := NewInputsForMode(map[string][]string{"value": {tt.val}}, schema, false) + require.NoError(t, err) + + got := inputs["value"] + switch { + case tt.wantInt != nil: + require.NotNil(t, got.Int, "expected int") + require.Equal(t, *tt.wantInt, *got.Int) + case tt.wantFlt != nil: + require.NotNil(t, got.Float, "expected float") + require.Equal(t, *tt.wantFlt, *got.Float) + case tt.wantStr != nil: + require.NotNil(t, got.String, "expected string") + require.Equal(t, *tt.wantStr, *got.String) + } + }) + } +} + +func TestSchemaAcceptsNumber(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.True(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"integer"}})) + require.False(t, schemaAcceptsNumber(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsNumber(nil)) + + union := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + }, + } + require.True(t, schemaAcceptsNumber(union)) + + stringOnlyUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"boolean"}}}, + }, + } + require.False(t, schemaAcceptsNumber(stringOnlyUnion)) +} + +func TestSchemaAcceptsString(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsString(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsString(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.False(t, schemaAcceptsString(nil)) + + union := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + }, + } + require.True(t, schemaAcceptsString(union)) + + numericOnlyUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + }, + } + require.False(t, schemaAcceptsString(numericOnlyUnion)) +} + +func TestSchemaAcceptsFloat(t *testing.T) { + t.Parallel() + + require.True(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"number"}})) + require.False(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"integer"}})) + require.False(t, schemaAcceptsFloat(&openapi3.Schema{Type: &openapi3.Types{"string"}})) + require.False(t, schemaAcceptsFloat(nil)) + + intFloatUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}}, + }, + } + require.True(t, schemaAcceptsFloat(intFloatUnion)) + + strIntUnion := &openapi3.Schema{ + AnyOf: openapi3.SchemaRefs{ + {Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}}, + {Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}}, + }, + } + require.False(t, schemaAcceptsFloat(strIntUnion)) +} + +func ptrI32(v int32) *int32 { return &v } +func ptrF32(v float32) *float32 { return &v } +func ptrStr(v string) *string { return &v } diff --git a/pkg/schema/openapi.go b/pkg/schema/openapi.go index f15f4471ca..41f604e462 100644 --- a/pkg/schema/openapi.go +++ b/pkg/schema/openapi.go @@ -268,6 +268,52 @@ type enumSchema struct { schema map[string]any } +func inputTypeJSONSchema(it InputType) map[string]any { + var schema map[string]any + switch it.Kind { + case InputKindPrimitive: + schema = it.Primitive.JSONType() + case InputKindAny: + schema = TypeAny.JSONType() + case InputKindArray: + items := TypeAny.JSONType() + if it.Elem != nil { + items = inputTypeJSONSchema(*it.Elem) + } + schema = map[string]any{ + "type": "array", + "items": items, + } + case InputKindUnion: + variants := make([]any, len(it.Variants)) + for i, variant := range it.Variants { + variantSchema := inputTypeJSONSchema(variant) + if it.Nullable { + variantSchema["nullable"] = true + } + variants[i] = variantSchema + } + // A nullable union is represented with OpenAPI's `nullable` keyword + // (set below), matching how plain optional fields behave: an omitted + // value yields the default, while explicit JSON `null` is validated + // against the field type just like any other optional input. + schema = map[string]any{"anyOf": variants} + default: + schema = TypeAny.JSONType() + } + if it.Nullable { + schema["nullable"] = true + } + return schema +} + +func inputSchemaForField(field InputField) map[string]any { + if field.InputType != nil { + return inputTypeJSONSchema(*field.InputType) + } + return field.FieldType.JSONType() +} + // buildInputSchema builds the Input schema object and any enum schemas for choices. func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { properties := newOrderedMapAny() @@ -314,7 +360,7 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { } else { // Regular field — inline type prop["title"] = TitleCase(name) - maps.Copy(prop, field.FieldType.JSONType()) + maps.Copy(prop, inputSchemaForField(field)) } // Determine effective default. A default of None on a non-nullable @@ -332,6 +378,9 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { // `Optional[Secret] = Input(default=None)`. See // TestNoneDefaultOnBareSecretIsOptional for the regression case. isNullable := field.FieldType.Repetition == Optional || field.FieldType.Repetition == OptionalRepeated + if field.InputType != nil && field.InputType.Nullable { + isNullable = true + } if field.FieldType.Primitive == TypeSecret && field.FieldType.Repetition == Required && field.Default != nil && field.Default.Kind == DefaultNone { @@ -343,7 +392,8 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { } // Required? - if !hasEffectiveDefault && (field.FieldType.Repetition == Required || field.FieldType.Repetition == Repeated) { + isUnionInput := field.InputType != nil && field.InputType.Kind == InputKindUnion + if !hasEffectiveDefault && (isUnionInput || field.FieldType.Repetition == Required || field.FieldType.Repetition == Repeated) { required = append(required, name) } diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index d9b528bb9f..57669b3e6c 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -40,6 +40,25 @@ func parseSpec(t *testing.T, info *PredictorInfo) map[string]any { return spec } +func extractInputProperty(t *testing.T, raw []byte, name string) string { + t.Helper() + var doc map[string]any + require.NoError(t, json.Unmarshal(raw, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties[name] + require.True(t, ok) + out, err := json.Marshal(prop) + require.NoError(t, err) + return string(out) +} + func getPath(m map[string]any, keys ...string) any { var cur any = m for _, k := range keys { @@ -534,6 +553,62 @@ func TestInputOptionalRepeatedType(t *testing.T) { assert.Nil(t, inputSchema["required"]) } +func TestOpenAPIUnionInputStringFloat(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + inputs.Set("value", InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: ptr(InputUnionOf( + InputPrimitive(TypeString), + InputPrimitive(TypeFloat), + )), + }) + + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + require.NoError(t, err) + require.JSONEq(t, `{"anyOf":[{"type":"string"},{"type":"number"}],"title":"Value","x-order":0}`, extractInputProperty(t, out, "value")) +} + +func TestOpenAPIRequiredNullableUnionInput(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + it := InputUnionOf(InputPrimitive(TypeString), InputPrimitive(TypeFloat)) + it.Nullable = true + inputs.Set("value", InputField{ + Name: "value", + Order: 0, + FieldType: FieldType{Primitive: TypeAny, Repetition: Required}, + InputType: &it, + }) + + out, err := GenerateOpenAPISchema(&PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + }) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + input := doc["components"].(map[string]any)["schemas"].(map[string]any)["Input"].(map[string]any) + require.Equal(t, []any{"value"}, input["required"]) + prop := input["properties"].(map[string]any)["value"].(map[string]any) + require.JSONEq(t, `{ + "anyOf":[ + {"nullable":true,"type":"string"}, + {"nullable":true,"type":"number"} + ], + "nullable":true, + "title":"Value", + "x-order":0 + }`, extractInputProperty(t, out, "value")) + require.Equal(t, true, prop["nullable"]) +} + // --------------------------------------------------------------------------- // Tests: Choices / Enums // --------------------------------------------------------------------------- diff --git a/pkg/schema/python/inputs.go b/pkg/schema/python/inputs.go index 7275eed636..06669f089d 100644 --- a/pkg/schema/python/inputs.go +++ b/pkg/schema/python/inputs.go @@ -262,16 +262,17 @@ func typedParameterParts(node *sitter.Node, source []byte) (string, *sitter.Node return name, typeNode } -func inputField(name string, order int, fieldType schema.FieldType) schema.InputField { +func inputField(name string, order int, inputType schema.InputType, fieldType schema.FieldType) schema.InputField { return schema.InputField{ Name: name, Order: order, FieldType: fieldType, + InputType: &inputType, } } -func inputFieldWithInfo(name string, order int, fieldType schema.FieldType, info inputCallInfo) schema.InputField { - field := inputField(name, order, fieldType) +func inputFieldWithInfo(name string, order int, inputType schema.InputType, fieldType schema.FieldType, info inputCallInfo) schema.InputField { + field := inputField(name, order, inputType, fieldType) field.Default = info.Default field.Description = info.Description field.GE = info.GE @@ -293,12 +294,12 @@ func firstParamIsSelf(params *sitter.Node, source []byte) bool { return false } -func resolveParameterFieldType(typeNode *sitter.Node, source []byte, ctx *inputParseContext) (schema.FieldType, error) { +func resolveParameterInputTypes(typeNode *sitter.Node, source []byte, ctx *inputParseContext) (schema.InputType, schema.FieldType, error) { typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { - return schema.FieldType{}, err + return schema.InputType{}, schema.FieldType{}, err } - return schema.ResolveFieldType(typeAnn, ctx.imports, ctx.typedDicts) + return schema.ResolveInputType(typeAnn, ctx.imports, ctx.typedDicts) } func extractInputs( @@ -362,12 +363,13 @@ func parseTypedParameter(node *sitter.Node, source []byte, order int, ctx *input return schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, ctx.methodName), nil) } - fieldType, err := resolveParameterFieldType(typeNode, source, ctx) + inputType, fieldType, err := resolveParameterInputTypes(typeNode, source, ctx) if err != nil { return schema.InputField{}, err } - return inputField(name, order, fieldType), nil + field := inputField(name, order, inputType, fieldType) + return field, schema.ValidateInputField(field) } func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx *inputParseContext) (schema.InputField, error) { @@ -382,7 +384,7 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx return schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, ctx.methodName), nil) } - fieldType, err := resolveParameterFieldType(typeNode, source, ctx) + inputType, fieldType, err := resolveParameterInputTypes(typeNode, source, ctx) if err != nil { return schema.InputField{}, err } @@ -396,19 +398,21 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx if err != nil { return schema.InputField{}, err } - return inputFieldWithInfo(name, order, fieldType, info), nil + field := inputFieldWithInfo(name, order, inputType, fieldType, info) + return field, schema.ValidateInputField(field) } // 2. Reference to Input() via class attribute or static method if info, ok := resolveInputReference(valNode, source, ctx.registry); ok { - return inputFieldWithInfo(name, order, fieldType, info), nil + field := inputFieldWithInfo(name, order, inputType, fieldType, info) + return field, schema.ValidateInputField(field) } // 3. Plain default — must be statically resolvable if def, ok := resolveDefaultExpr(valNode, source, ctx.scope); ok { - field := inputField(name, order, fieldType) + field := inputField(name, order, inputType, fieldType) field.Default = &def - return field, nil + return field, schema.ValidateInputField(field) } // Can't resolve — hard error @@ -418,7 +422,8 @@ func parseTypedDefaultParameter(node *sitter.Node, source []byte, order int, ctx } // No default — required parameter - return inputField(name, order, fieldType), nil + field := inputField(name, order, inputType, fieldType) + return field, schema.ValidateInputField(field) } func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 1d26dace12..de0a71f267 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -404,6 +404,129 @@ class Predictor(BasePredictor): require.Equal(t, schema.TypeString, name.FieldType.Primitive) } +func TestOptionalInputOpenAPINotRequired(t *testing.T) { + source := []byte(` +from typing import Optional + +class Predictor: + def predict(self, value: Optional[str]) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + out, err := schema.GenerateOpenAPISchema(info) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + _, hasRequired := input["required"] + require.False(t, hasRequired) + + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties["value"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, prop["nullable"]) +} + +func TestUnionInputStringFloat(t *testing.T) { + source := []byte(` +class Predictor: + def predict(self, value: str | float) -> str: + return str(value) +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + value, ok := info.Inputs.Get("value") + require.True(t, ok) + require.NotNil(t, value.InputType) + require.Equal(t, schema.InputKindUnion, value.InputType.Kind) + require.Len(t, value.InputType.Variants, 2) + require.Equal(t, schema.TypeString, value.InputType.Variants[0].Primitive) + require.Equal(t, schema.TypeFloat, value.InputType.Variants[1].Primitive) + require.False(t, value.InputType.Nullable) +} + +func TestUnionInputStringFloatNone(t *testing.T) { + source := []byte(` +from cog import Input + +class Predictor: + def predict(self, value: str | float | None = Input(default=None)) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + value, ok := info.Inputs.Get("value") + require.True(t, ok) + require.NotNil(t, value.InputType) + require.Equal(t, schema.InputKindUnion, value.InputType.Kind) + require.True(t, value.InputType.Nullable) + require.NotNil(t, value.Default) + require.Equal(t, schema.DefaultNone, value.Default.Kind) +} + +func TestUnionInputNullableWithoutDefaultOpenAPI(t *testing.T) { + source := []byte(` +class Predictor: + def predict(self, value: str | float | None) -> str: + return "ok" +`) + + info, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.NoError(t, err) + + out, err := schema.GenerateOpenAPISchema(info) + require.NoError(t, err) + + var doc map[string]any + require.NoError(t, json.Unmarshal(out, &doc)) + components, ok := doc["components"].(map[string]any) + require.True(t, ok) + schemas, ok := components["schemas"].(map[string]any) + require.True(t, ok) + input, ok := schemas["Input"].(map[string]any) + require.True(t, ok) + require.Equal(t, []any{"value"}, input["required"]) + + properties, ok := input["properties"].(map[string]any) + require.True(t, ok) + prop, ok := properties["value"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, prop["nullable"]) + require.Equal(t, []any{ + map[string]any{"nullable": true, "type": "string"}, + map[string]any{"nullable": true, "type": "number"}, + }, prop["anyOf"]) +} + +func TestUnionInputRejectsPathString(t *testing.T) { + source := []byte(` +from cog import Path + +class Predictor: + def predict(self, value: Path | str) -> str: + return "ok" +`) + + _, err := ParsePredictor(source, "Predictor", schema.ModePredict, "") + require.Error(t, err) + require.Contains(t, err.Error(), "Path") + require.Contains(t, err.Error(), "union") +} + // --------------------------------------------------------------------------- // List inputs // --------------------------------------------------------------------------- diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 4013424a69..0d86535801 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -98,6 +98,48 @@ type FieldType struct { Repetition Repetition } +// InputTypeKind tags the recursive input type representation. +type InputTypeKind int + +const ( + InputKindPrimitive InputTypeKind = iota + InputKindAny + InputKindArray + InputKindUnion +) + +// InputType represents JSON-native input types, including unions. +type InputType struct { + Kind InputTypeKind + Primitive PrimitiveType + Elem *InputType + Variants []InputType + Nullable bool +} + +// InputPrimitive creates a primitive input type. +func InputPrimitive(primitive PrimitiveType) InputType { + if primitive == TypeAny { + return InputAnyType() + } + return InputType{Kind: InputKindPrimitive, Primitive: primitive} +} + +// InputAnyType creates an opaque JSON input type. +func InputAnyType() InputType { + return InputType{Kind: InputKindAny, Primitive: TypeAny} +} + +// InputArrayOf creates an array input type. +func InputArrayOf(elem InputType) InputType { + return InputType{Kind: InputKindArray, Elem: &elem} +} + +// InputUnionOf creates a union input type. +func InputUnionOf(variants ...InputType) InputType { + return InputType{Kind: InputKindUnion, Variants: variants} +} + // JSONType returns the JSON Schema fragment for this field type. func (ft FieldType) JSONType() map[string]any { if ft.Repetition == Repeated || ft.Repetition == OptionalRepeated { @@ -179,6 +221,7 @@ type InputField struct { Name string Order int FieldType FieldType + InputType *InputType Default *DefaultValue Description *string GE *float64 @@ -195,6 +238,16 @@ func (f *InputField) IsRequired() bool { return f.Default == nil && (f.FieldType.Repetition == Required || f.FieldType.Repetition == Repeated) } +// ValidateInputField checks combinations unsupported by the static input model. +func ValidateInputField(field InputField) error { + if field.InputType != nil && field.InputType.Kind == InputKindUnion { + if len(field.Choices) > 0 || field.GE != nil || field.LE != nil || field.MinLength != nil || field.MaxLength != nil || field.Regex != nil { + return errUnsupportedType("constraints and choices are not supported on union inputs") + } + } + return nil +} + // PredictorInfo is the top-level extraction result. type PredictorInfo struct { Inputs *OrderedMap[string, InputField] @@ -434,6 +487,208 @@ func ResolveFieldType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[str return FieldType{}, errUnsupportedType("unknown type annotation") } +// ResolveInputType resolves a TypeAnnotation into the recursive input type model +// and the legacy FieldType compatibility layer. +func ResolveInputType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, FieldType, error) { + inputType, err := resolveInputType(ann, ctx, typedDicts) + if err != nil { + return InputType{}, FieldType{}, err + } + return inputType, fieldTypeFromInputType(inputType), nil +} + +func resolveInputType(ann TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, error) { + if inner, ok := unwrapOpaqueAnnotated(ann, ctx); ok { + return inputTypeFromFieldType(opaqueFieldType(inner, ctx)), nil + } + + switch ann.Kind { + case TypeAnnotSimple: + name := ann.Name + if typedDicts[name] { + return InputAnyType(), nil + } + qualifiedEntry := ImportEntry{} + if resolved, entry, ok := ctx.ResolveQualifiedName(name); ok { + name = resolved + qualifiedEntry = entry + if typedDicts[entry.Original+"."+name] { + return InputAnyType(), nil + } + } + if typedDicts[name] { + return InputAnyType(), nil + } + if name == "dict" || name == "Dict" { + return InputAnyType(), nil + } + prim, ok := PrimitiveFromName(name) + if !ok { + if qualifiedEntry.Module != "" { + return InputType{}, errUnresolvableImportedType(name, qualifiedEntry.Module) + } + if entry, imported := ctx.Names.Get(name); imported { + return InputType{}, errUnresolvableImportedType(name, entry.Module) + } + return InputType{}, errUnsupportedType(name) + } + return InputPrimitive(prim), nil + + case TypeAnnotGeneric: + outer := ann.Name + if resolved, _, ok := ctx.ResolveQualifiedName(outer); ok { + outer = resolved + } + if outer == "dict" || outer == "Dict" { + return InputAnyType(), nil + } + if outer == "Optional" { + if len(ann.Args) != 1 { + return InputType{}, errUnsupportedType(fmt.Sprintf("Optional expects exactly 1 type argument, got %d", len(ann.Args))) + } + inner, err := resolveInputType(ann.Args[0], ctx, typedDicts) + if err != nil { + return InputType{}, err + } + inner.Nullable = true + return inner, nil + } + if outer == "Union" { + return resolveInputUnion(ann.Args, ctx, typedDicts) + } + if ctx.isAnnotated(ann.Name) { + if len(ann.Args) == 0 { + return InputType{}, errUnsupportedType("Annotated expects at least 1 type argument") + } + return resolveInputType(ann.Args[0], ctx, typedDicts) + } + if outer == "List" || outer == "list" { + if len(ann.Args) != 1 { + return InputType{}, errUnsupportedType(fmt.Sprintf("List expects exactly 1 type argument, got %d", len(ann.Args))) + } + if opaqueInner, ok := unwrapOpaqueAnnotated(ann.Args[0], ctx); ok { + inner := inputTypeFromFieldType(opaqueFieldType(opaqueInner, ctx)) + if inner.Nullable || inner.Kind == InputKindArray || inner.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested generics like List[Optional[X]] are not supported") + } + return InputArrayOf(inner), nil + } + inner, err := resolveInputType(ann.Args[0], ctx, typedDicts) + if err != nil { + return InputType{}, err + } + if inner.Nullable || inner.Kind == InputKindArray || inner.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested generics like List[Optional[X]] are not supported") + } + return InputArrayOf(inner), nil + } + return InputType{}, errUnsupportedType(fmt.Sprintf("%s[...] is not a supported input type", outer)) + + case TypeAnnotUnion: + return resolveInputUnion(ann.Args, ctx, typedDicts) + } + return InputType{}, errUnsupportedType("unknown type annotation") +} + +func resolveInputUnion(args []TypeAnnotation, ctx *ImportContext, typedDicts map[string]bool) (InputType, error) { + variants := make([]InputType, 0, len(args)) + nullable := false + for _, arg := range args { + if arg.Kind == TypeAnnotSimple && arg.Name == "None" { + nullable = true + continue + } + if arg.Kind == TypeAnnotUnion || (arg.Kind == TypeAnnotGeneric && arg.Name == "Union") { + return InputType{}, errUnsupportedType("nested union inputs are not supported") + } + variant, err := resolveInputType(arg, ctx, typedDicts) + if err != nil { + return InputType{}, err + } + if variant.Kind == InputKindUnion { + return InputType{}, errUnsupportedType("nested union inputs are not supported") + } + variants = append(variants, variant) + } + + if len(variants) == 0 { + return InputType{}, errUnsupportedType("union inputs must include at least one non-None type") + } + if len(variants) == 1 { + variant := variants[0] + variant.Nullable = variant.Nullable || nullable + return variant, nil + } + for _, variant := range variants { + if err := validateUnionVariant(variant); err != nil { + return InputType{}, err + } + } + + union := InputUnionOf(variants...) + union.Nullable = nullable + return union, nil +} + +func validateUnionVariant(inputType InputType) error { + if inputType.Nullable { + return errUnsupportedType("nested nullable variants are not supported in union inputs") + } + switch inputType.Kind { + case InputKindPrimitive: + if inputType.Primitive == TypePath || inputType.Primitive == TypeFile || inputType.Primitive == TypeSecret { + return errUnsupportedType(fmt.Sprintf("%s is not supported in union inputs", inputType.Primitive)) + } + case InputKindArray: + if inputType.Elem != nil { + return validateUnionVariant(*inputType.Elem) + } + case InputKindUnion: + return errUnsupportedType("nested union inputs are not supported") + } + return nil +} + +func inputTypeFromFieldType(fieldType FieldType) InputType { + var inputType InputType + if fieldType.Primitive == TypeAny { + inputType = InputAnyType() + } else { + inputType = InputPrimitive(fieldType.Primitive) + } + if fieldType.Repetition == Repeated || fieldType.Repetition == OptionalRepeated { + inputType = InputArrayOf(inputType) + } + if fieldType.Repetition == Optional || fieldType.Repetition == OptionalRepeated { + inputType.Nullable = true + } + return inputType +} + +func fieldTypeFromInputType(inputType InputType) FieldType { + repetition := Required + if inputType.Nullable { + repetition = Optional + } + switch inputType.Kind { + case InputKindPrimitive: + return FieldType{Primitive: inputType.Primitive, Repetition: repetition} + case InputKindArray: + arrayRepetition := Repeated + if inputType.Nullable { + arrayRepetition = OptionalRepeated + } + if inputType.Elem != nil && inputType.Elem.Kind == InputKindPrimitive { + return FieldType{Primitive: inputType.Elem.Primitive, Repetition: arrayRepetition} + } + return FieldType{Primitive: TypeAny, Repetition: arrayRepetition} + case InputKindAny, InputKindUnion: + return FieldType{Primitive: TypeAny, Repetition: repetition} + default: + return FieldType{Primitive: TypeAny, Repetition: repetition} + } +} + func unwrapOpaqueAnnotated(ann TypeAnnotation, ctx *ImportContext) (TypeAnnotation, bool) { if ann.Kind != TypeAnnotGeneric || !ctx.isAnnotated(ann.Name) || len(ann.Args) < 2 { return ann, false diff --git a/python/cog/_adt.py b/python/cog/_adt.py index f40f870e47..3fea80cc27 100644 --- a/python/cog/_adt.py +++ b/python/cog/_adt.py @@ -37,6 +37,10 @@ def _is_union(tpe: type) -> bool: return False +def _is_none_type(tpe: Any) -> bool: + return tpe is None or tpe is type(None) + + def _is_dict_like(tpe: Any) -> bool: """Check if a type should be treated like a dict, including TypedDict.""" if tpe is dict: @@ -52,7 +56,7 @@ def _is_dict_like(tpe: Any) -> bool: pass is_typeddict = getattr(typing, "is_typeddict", None) - return callable(is_typeddict) and is_typeddict(tpe) + return bool(callable(is_typeddict) and is_typeddict(tpe)) def _unwrap_opaque(tpe: Any) -> tuple[Any, bool]: @@ -249,6 +253,90 @@ class Repetition(Enum): OPTIONAL_REPEATED = 4 # list[X] | None +def _is_supported_union_variant(ft: "FieldType") -> bool: + if ft.union_variants is not None: + return False + if ft.primitive in { + PrimitiveType.PATH, + PrimitiveType.FILE, + PrimitiveType.SECRET, + PrimitiveType.CUSTOM, + }: + return False + return ft.repetition in {Repetition.REQUIRED, Repetition.REPEATED} + + +def _is_exact_union_match(value: Any, ft: "FieldType") -> bool: + if ft.repetition is Repetition.REPEATED: + return isinstance(value, list) + if ft.repetition is not Repetition.REQUIRED: + return False + if ft.primitive is PrimitiveType.BOOL: + return isinstance(value, bool) + if ft.primitive is PrimitiveType.INTEGER: + return isinstance(value, int) and not isinstance(value, bool) + if ft.primitive is PrimitiveType.FLOAT: + return isinstance(value, float) + if ft.primitive is PrimitiveType.STRING: + return isinstance(value, str) + if ft.primitive is PrimitiveType.ANY: + return isinstance(value, dict) + return False + + +def _union_primitive_accepts_value(value: Any, primitive: PrimitiveType) -> bool: + if primitive is PrimitiveType.BOOL: + return type(value) is bool + if primitive is PrimitiveType.INTEGER: + return type(value) is int + if primitive is PrimitiveType.FLOAT: + return type(value) is float or type(value) is int + if primitive is PrimitiveType.STRING: + return type(value) is str + if primitive is PrimitiveType.ANY: + return isinstance(value, dict) + return False + + +def _union_variant_accepts_value(value: Any, variant: "FieldType") -> bool: + if variant.repetition is Repetition.REPEATED: + return isinstance(value, list) and all( + _union_primitive_accepts_value(element, variant.primitive) + for element in value + ) + if variant.repetition is not Repetition.REQUIRED: + return False + return _union_primitive_accepts_value(value, variant.primitive) + + +def _union_variant_priority(ft: "FieldType") -> int: + primitive_priority = { + PrimitiveType.BOOL: 0, + PrimitiveType.INTEGER: 1, + PrimitiveType.FLOAT: 2, + PrimitiveType.STRING: 3, + PrimitiveType.ANY: 4, + } + repetition_offset = 10 if ft.repetition is Repetition.REPEATED else 0 + return repetition_offset + primitive_priority.get(ft.primitive, 100) + + +def _ordered_union_variants( + value: Any, variants: List["FieldType"] +) -> List["FieldType"]: + return [ + variant + for _, variant in sorted( + enumerate(variants), + key=lambda item: ( + not _is_exact_union_match(value, item[1]), + _union_variant_priority(item[1]), + item[0], + ), + ) + ] + + @dataclass(frozen=True) class FieldType: """Type information for an input/output field.""" @@ -256,6 +344,7 @@ class FieldType: primitive: PrimitiveType repetition: Repetition coder: Optional[Coder] + union_variants: Optional[List["FieldType"]] = None @staticmethod def from_type(tpe: type) -> "FieldType": @@ -317,9 +406,43 @@ def from_type(tpe: type) -> "FieldType": elif _is_union(tpe): t_args = typing.get_args(tpe) - if not (len(t_args) == 2 and type(None) in t_args): - raise ValueError(f"unsupported union type {tpe}") - elem_t = t_args[0] if t_args[1] is type(None) else t_args[1] + has_none = any(_is_none_type(arg) for arg in t_args) + non_none_args = [arg for arg in t_args if not _is_none_type(arg)] + + if len(non_none_args) != 1: + repetition = Repetition.OPTIONAL if has_none else Repetition.REQUIRED + variants = [] + for arg in non_none_args: + try: + variant = FieldType.from_type(arg) + except ValueError as exc: + raise ValueError( + f"unsupported union member {_type_name(arg)} in union {tpe}" + ) from exc + if not _is_supported_union_variant(variant): + raise ValueError( + f"unsupported union member {_type_name(arg)} in union {tpe}" + ) + variants.append(variant) + return FieldType( + primitive=PrimitiveType.ANY, + repetition=repetition, + coder=None, + union_variants=variants, + ) + + if not has_none: + elem_t = non_none_args[0] + repetition = Repetition.REQUIRED + cog_t = PrimitiveType.from_type(elem_t) + coder = None + if cog_t is PrimitiveType.CUSTOM: + coder = Coder.lookup(elem_t) + if coder is None: + raise ValueError(f"unsupported Cog type {_type_name(elem_t)}") + return FieldType(primitive=cog_t, repetition=repetition, coder=coder) + + elem_t = non_none_args[0] inner, elem_is_opaque = _unwrap_opaque(elem_t) if elem_is_opaque: return _opaque_field_type(inner | None) @@ -384,6 +507,20 @@ def from_type(tpe: type) -> "FieldType": def normalize(self, value: Any) -> Any: """Normalize a value according to this field type.""" + if self.union_variants is not None: + if value is None: + if self.repetition is Repetition.OPTIONAL: + return None + raise ValueError("missing value for required union field") + for variant in _ordered_union_variants(value, self.union_variants): + if not _union_variant_accepts_value(value, variant): + continue + try: + return variant.normalize(value) + except (TypeError, ValueError): + pass + raise ValueError(f"failed to normalize value as {self.python_type_name()}") + if self.repetition is Repetition.REQUIRED: return self.primitive.normalize(value) elif self.repetition is Repetition.OPTIONAL: @@ -398,6 +535,14 @@ def normalize(self, value: Any) -> Any: def json_type(self) -> Dict[str, Any]: """Get the JSON Schema type for this field.""" + if self.union_variants is not None: + jt: Dict[str, Any] = { + "anyOf": [variant.json_type() for variant in self.union_variants] + } + if self.repetition is Repetition.OPTIONAL: + jt["nullable"] = True + return jt + if self.repetition in (Repetition.REPEATED, Repetition.OPTIONAL_REPEATED): return {"type": "array", "items": self.primitive.json_type()} return self.primitive.json_type() @@ -428,6 +573,14 @@ def json_decode(self, value: Any) -> Any: def python_type_name(self) -> str: """Get the Python type name for this field.""" + if self.union_variants is not None: + name = " | ".join( + variant.python_type_name() for variant in self.union_variants + ) + if self.repetition is Repetition.OPTIONAL: + return f"Optional[{name}]" + return name + if self.repetition is Repetition.REQUIRED: return self.primitive.python_type_name() elif self.repetition is Repetition.OPTIONAL: diff --git a/python/tests/test_adt.py b/python/tests/test_adt.py index ebffeb7d72..22800bf36b 100644 --- a/python/tests/test_adt.py +++ b/python/tests/test_adt.py @@ -2,6 +2,7 @@ from typing import Annotated, Any, Dict, List, Optional, TypedDict +import pytest from typing_extensions import TypedDict as ExtensionsTypedDict from cog import Opaque @@ -322,3 +323,106 @@ def test_optional_str(self) -> None: ft = FieldType.from_type(Optional[str]) assert ft.primitive is PrimitiveType.STRING assert ft.repetition is Repetition.OPTIONAL + + +class TestUnionInputTypes: + def test_union_str_float_field_type(self) -> None: + ft = FieldType.from_type(str | float) + assert ft.primitive is PrimitiveType.ANY + assert ft.repetition is Repetition.REQUIRED + assert ft.union_variants is not None + assert [v.primitive for v in ft.union_variants] == [ + PrimitiveType.STRING, + PrimitiveType.FLOAT, + ] + + def test_union_str_float_none_field_type(self) -> None: + ft = FieldType.from_type(str | float | None) + assert ft.repetition is Repetition.OPTIONAL + assert ft.union_variants is not None + + def test_union_int_float_prefers_int(self) -> None: + ft = FieldType.from_type(int | float) + assert ft.normalize(1) == 1 + assert isinstance(ft.normalize(1), int) + + def test_union_bool_int_prefers_bool(self) -> None: + ft = FieldType.from_type(bool | int) + value = ft.normalize(True) + assert value is True + + def test_union_int_float_rejects_bool(self) -> None: + ft = FieldType.from_type(int | float) + with pytest.raises(ValueError): + ft.normalize(True) + + def test_union_str_bool_rejects_int(self) -> None: + ft = FieldType.from_type(str | bool) + with pytest.raises(ValueError): + ft.normalize(1) + + def test_union_str_dict_rejects_scalar(self) -> None: + ft = FieldType.from_type(str | dict) + with pytest.raises(ValueError): + ft.normalize(123) + + def test_union_list_int_float_rejects_bool_element(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + with pytest.raises(ValueError): + ft.normalize([True]) + + def test_union_list_int_float_rejects_string_element(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + with pytest.raises(ValueError): + ft.normalize(["3"]) + + def test_union_list_int_float_accepts_numeric_elements(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + assert ft.normalize([1]) == [1] + assert isinstance(ft.normalize([1])[0], int) + assert ft.normalize([1.5]) == [1.5] + + def test_union_optional_normalize_none(self) -> None: + ft = FieldType.from_type(str | float | None) + assert ft.repetition is Repetition.OPTIONAL + assert ft.normalize(None) is None + + def test_union_required_normalize_none_raises(self) -> None: + ft = FieldType.from_type(str | float) + assert ft.repetition is Repetition.REQUIRED + with pytest.raises(ValueError): + ft.normalize(None) + + def test_union_required_json_type_omits_nullable(self) -> None: + ft = FieldType.from_type(int | str) + assert ft.json_type() == { + "anyOf": [{"type": "integer"}, {"type": "string"}], + } + + def test_union_list_int_float_accepts_empty_list(self) -> None: + ft = FieldType.from_type(list[int] | list[float]) + assert ft.normalize([]) == [] + + def test_union_mixed_scalar_and_list(self) -> None: + ft = FieldType.from_type(list[int] | int) + assert ft.normalize(5) == 5 + assert isinstance(ft.normalize(5), int) + assert ft.normalize([5]) == [5] + + def test_union_str_float_none_json_type(self) -> None: + ft = FieldType.from_type(str | float | None) + assert ft.json_type() == { + "anyOf": [{"type": "string"}, {"type": "number"}], + "nullable": True, + } + + def test_union_rejects_path_string(self) -> None: + from cog import Path + + try: + FieldType.from_type(Path | str) + except ValueError as exc: + assert "Path" in str(exc) + assert "union" in str(exc) + else: + raise AssertionError("Expected ValueError for Path | str") diff --git a/python/tests/test_inspector.py b/python/tests/test_inspector.py index 7597a783c2..bf1307aae9 100644 --- a/python/tests/test_inspector.py +++ b/python/tests/test_inspector.py @@ -168,6 +168,22 @@ def predict(self, value: Annotated[ExternalObject, Opaque]) -> str: assert field.type.repetition is adt.Repetition.REQUIRED +def test_inspector_supports_union_input() -> None: + class Predictor: + def predict(self, value: str | float) -> str: + return str(value) + + info = _create_predictor_info( + "predict", "Predictor", Predictor.predict, "predict", True + ) + field = info.inputs["value"] + assert field.type.union_variants is not None + assert [v.primitive for v in field.type.union_variants] == [ + adt.PrimitiveType.STRING, + adt.PrimitiveType.FLOAT, + ] + + def test_inspector_preserves_opaque_list_input_metadata() -> None: class Predictor: def predict(self, value: Annotated[List[ExternalObject], Opaque]) -> str: