From 2526fc7675b1f21ab036fb96453fb5df97e33de0 Mon Sep 17 00:00:00 2001 From: Dawei Date: Wed, 4 Feb 2026 10:39:16 -0800 Subject: [PATCH 1/2] Fix mixed anyOf unions --- src/replit_river/codegen/client.py | 76 +++++++---- .../test_anyof_mixed_types/__init__.py | 13 ++ .../test_service/__init__.py | 44 +++++++ .../test_service/anyof_mixed_method.py | 118 ++++++++++++++++++ tests/v1/codegen/snapshot/test_anyof_mixed.py | 20 +++ .../v1/codegen/types/anyof_mixed_schema.json | 85 +++++++++++++ 6 files changed, 333 insertions(+), 23 deletions(-) create mode 100644 tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/__init__.py create mode 100644 tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py create mode 100644 tests/v1/codegen/snapshot/test_anyof_mixed.py create mode 100644 tests/v1/codegen/types/anyof_mixed_schema.json diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 1d85526d..8c53d340 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -389,7 +389,8 @@ def {_field_name}( type = original_type any_of: list[TypeExpression] = [] - typeddict_encoder = [] + # Collect (type_check, encoder_expr) pairs for building ternary chain + encoder_parts: list[tuple[str | None, str]] = [] for i, t in enumerate(type.anyOf): type_name, _, contents, _ = encode_type( t, @@ -403,34 +404,63 @@ def {_field_name}( chunks.extend(contents) if isinstance(t, RiverConcreteType): if t.type == "string": - typeddict_encoder.extend(["x", " if isinstance(x, str) else "]) - else: - # TODO(dstewart): This structure changed since we were incorrectly - # leaking ListTypeExprs into codegen. This generated - # code is probably wrong. + encoder_parts.append(("isinstance(x, str)", "x")) + elif t.type == "array": match type_name: case ListTypeExpr(inner_type_name): - typeddict_encoder.append( - f"encode_{render_literal_type(inner_type_name)}(x)" + # Primitives don't need encoding + inner_type_str = render_literal_type(inner_type_name) + if inner_type_str in ("str", "int", "float", "bool", "Any"): + encoder_parts.append(("isinstance(x, list)", "list(x)")) + else: + encoder_parts.append( + ( + "isinstance(x, list)", + f"[encode_{inner_type_str}(y) for y in x]", + ) + ) + case _: + encoder_parts.append(("isinstance(x, list)", "list(x)")) + elif t.type == "object": + match type_name: + case TypeName(value): + encoder_parts.append( + ("isinstance(x, dict)", f"encode_{value}(x)") ) + case _: + encoder_parts.append(("isinstance(x, dict)", "dict(x)")) + elif t.type in ("number", "integer"): + match type_name: case LiteralTypeExpr(const): - typeddict_encoder.append(repr(const)) + encoder_parts.append((f"x == {repr(const)}", repr(const))) + case _: + encoder_parts.append(("isinstance(x, (int, float))", "x")) + elif t.type == "boolean": + encoder_parts.append(("isinstance(x, bool)", "x")) + elif t.type == "null" or t.type == "undefined": + encoder_parts.append(("x is None", "None")) + else: + # Fallback for other types + match type_name: case TypeName(value): - typeddict_encoder.append(f"encode_{value}(x)") + encoder_parts.append((None, f"encode_{value}(x)")) + case LiteralTypeExpr(const): + encoder_parts.append((None, repr(const))) case NoneTypeExpr(): - typeddict_encoder.append("None") - case other: - _o2: ( - DictTypeExpr - | OpenUnionTypeExpr - | UnionTypeExpr - | LiteralType - ) = other - raise ValueError( - f"What does it mean to have { - render_type_expr(_o2) - } here?" - ) + encoder_parts.append((None, "None")) + case _: + encoder_parts.append((None, "x")) + + # Build the ternary chain from encoder_parts + typeddict_encoder: list[str] = [] + for i, (type_check, encoder_expr) in enumerate(encoder_parts): + is_last = i == len(encoder_parts) - 1 + if is_last or type_check is None: + # Last item or no type check - just the expression + typeddict_encoder.append(encoder_expr) + else: + # Add expression with type check + typeddict_encoder.append(f"{encoder_expr} if {type_check} else") if permit_unknown_members: union = _make_open_union_type_expr(any_of) else: diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/__init__.py new file mode 100644 index 00000000..bb6db710 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/__init__.py @@ -0,0 +1,13 @@ +# Code generated by river.codegen. DO NOT EDIT. +from pydantic import BaseModel +from typing import Literal + +import replit_river as river + + +from .test_service import Test_ServiceService + + +class AnyOfMixedClient: + def __init__(self, client: river.Client[Literal[None]]): + self.test_service = Test_ServiceService(client) diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/__init__.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/__init__.py new file mode 100644 index 00000000..061cf29b --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/__init__.py @@ -0,0 +1,44 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any +import datetime + +from pydantic import TypeAdapter + +from replit_river.error_schema import RiverError, RiverErrorTypeAdapter +import replit_river as river + + +from .anyof_mixed_method import ( + Anyof_Mixed_MethodInput, + Anyof_Mixed_MethodOutput, + Anyof_Mixed_MethodOutputTypeAdapter, + encode_Anyof_Mixed_MethodInput, + encode_Anyof_Mixed_MethodInputNumber_Or_String, + encode_Anyof_Mixed_MethodInputRun_Command, + encode_Anyof_Mixed_MethodInputValue_Or_Null, +) + + +class Test_ServiceService: + def __init__(self, client: river.Client[Any]): + self.client = client + + async def anyof_mixed_method( + self, + input: Anyof_Mixed_MethodInput, + timeout: datetime.timedelta, + ) -> Anyof_Mixed_MethodOutput: + return await self.client.send_rpc( + "test_service", + "anyof_mixed_method", + input, + encode_Anyof_Mixed_MethodInput, + lambda x: Anyof_Mixed_MethodOutputTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + lambda x: RiverErrorTypeAdapter.validate_python( + x # type: ignore[arg-type] + ), + timeout, + ) diff --git a/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py new file mode 100644 index 00000000..4846c795 --- /dev/null +++ b/tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py @@ -0,0 +1,118 @@ +# Code generated by river.codegen. DO NOT EDIT. +from collections.abc import AsyncIterable, AsyncIterator +import datetime +from typing import ( + Any, + Literal, + Mapping, + NotRequired, + TypedDict, +) +from typing_extensions import Annotated + +from pydantic import BaseModel, Field, TypeAdapter, WrapValidator +from replit_river.error_schema import RiverError +from replit_river.client import ( + RiverUnknownError, + translate_unknown_error, + RiverUnknownValue, + translate_unknown_value, +) + +import replit_river as river + + +Anyof_Mixed_MethodInputNumber_Or_String = float | str + + +def encode_Anyof_Mixed_MethodInputNumber_Or_String( + x: "Anyof_Mixed_MethodInputNumber_Or_String", +) -> Any: + return x + + +def encode_Anyof_Mixed_MethodInputRun_CommandAnyOf_0( + x: "Anyof_Mixed_MethodInputRun_CommandAnyOf_0", +) -> Any: + return { + k: v + for (k, v) in ( + { + "args": x.get("args"), + "env": x.get("env"), + } + ).items() + if v is not None + } + + +class Anyof_Mixed_MethodInputRun_CommandAnyOf_0(TypedDict): + args: list[str] + env: NotRequired[dict[str, str] | None] + + +Anyof_Mixed_MethodInputRun_Command = ( + Anyof_Mixed_MethodInputRun_CommandAnyOf_0 | str | list[str] +) + + +def encode_Anyof_Mixed_MethodInputRun_Command( + x: "Anyof_Mixed_MethodInputRun_Command", +) -> Any: + return ( + encode_Anyof_Mixed_MethodInputRun_CommandAnyOf_0(x) + if isinstance(x, dict) + else x + if isinstance(x, str) + else list(x) + ) + + +Anyof_Mixed_MethodInputValue_Or_Null = str | None + + +def encode_Anyof_Mixed_MethodInputValue_Or_Null( + x: "Anyof_Mixed_MethodInputValue_Or_Null", +) -> Any: + return x + + +def encode_Anyof_Mixed_MethodInput( + x: "Anyof_Mixed_MethodInput", +) -> Any: + return { + k: v + for (k, v) in ( + { + "number_or_string": encode_Anyof_Mixed_MethodInputNumber_Or_String( + x["number_or_string"] + ) + if "number_or_string" in x and x["number_or_string"] + else None, + "run_command": encode_Anyof_Mixed_MethodInputRun_Command( + x["run_command"] + ), + "value_or_null": encode_Anyof_Mixed_MethodInputValue_Or_Null( + x["value_or_null"] + ) + if "value_or_null" in x and x["value_or_null"] + else None, + } + ).items() + if v is not None + } + + +class Anyof_Mixed_MethodInput(TypedDict): + number_or_string: NotRequired[Anyof_Mixed_MethodInputNumber_Or_String | None] + run_command: Anyof_Mixed_MethodInputRun_Command + value_or_null: NotRequired[Anyof_Mixed_MethodInputValue_Or_Null | None] + + +class Anyof_Mixed_MethodOutput(BaseModel): + success: bool + + +Anyof_Mixed_MethodOutputTypeAdapter: TypeAdapter[Anyof_Mixed_MethodOutput] = ( + TypeAdapter(Anyof_Mixed_MethodOutput) +) diff --git a/tests/v1/codegen/snapshot/test_anyof_mixed.py b/tests/v1/codegen/snapshot/test_anyof_mixed.py new file mode 100644 index 00000000..cd29787f --- /dev/null +++ b/tests/v1/codegen/snapshot/test_anyof_mixed.py @@ -0,0 +1,20 @@ +from pytest_snapshot.plugin import Snapshot + +from tests.fixtures.codegen_snapshot_fixtures import validate_codegen + + +async def test_anyof_mixed_types(snapshot: Snapshot) -> None: + """Test codegen for anyOf unions with mixed types (object, string, array). + + This tests the fix for the bug where non-discriminated anyOf unions + with mixed types like [object, string, array] would generate malformed + Python code with broken ternary expressions. + """ + validate_codegen( + snapshot=snapshot, + snapshot_dir="tests/v1/codegen/snapshot/snapshots", + read_schema=lambda: open("tests/v1/codegen/types/anyof_mixed_schema.json"), + target_path="test_anyof_mixed_types", + client_name="AnyOfMixedClient", + protocol_version="v1.1", + ) diff --git a/tests/v1/codegen/types/anyof_mixed_schema.json b/tests/v1/codegen/types/anyof_mixed_schema.json new file mode 100644 index 00000000..b6a17729 --- /dev/null +++ b/tests/v1/codegen/types/anyof_mixed_schema.json @@ -0,0 +1,85 @@ +{ + "services": { + "test_service": { + "procedures": { + "anyof_mixed_method": { + "input": { + "type": "object", + "properties": { + "run_command": { + "description": "Command can be object with args, string, or array of strings", + "anyOf": [ + { + "type": "object", + "properties": { + "args": { + "type": "array", + "items": { + "type": "string" + } + }, + "env": { + "type": "object", + "patternProperties": { + "^(.*)$": { + "type": "string" + } + } + } + }, + "required": ["args"] + }, + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ] + }, + "value_or_null": { + "description": "Value can be string or null", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "number_or_string": { + "description": "Can be number or string", + "anyOf": [ + { + "type": "number" + }, + { + "type": "string" + } + ] + } + }, + "required": ["run_command"] + }, + "output": { + "type": "object", + "properties": { + "success": { + "type": "boolean" + } + }, + "required": ["success"] + }, + "errors": { + "not": {} + }, + "type": "rpc" + } + } + } + } +} From bb92425954013b5e8b5c5ada626f71d539ad7968 Mon Sep 17 00:00:00 2001 From: Dawei Date: Wed, 4 Feb 2026 10:48:45 -0800 Subject: [PATCH 2/2] CI --- src/replit_river/codegen/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 8c53d340..e1a8d6cd 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -452,7 +452,7 @@ def {_field_name}( encoder_parts.append((None, "x")) # Build the ternary chain from encoder_parts - typeddict_encoder: list[str] = [] + typeddict_encoder = list[str]() for i, (type_check, encoder_expr) in enumerate(encoder_parts): is_last = i == len(encoder_parts) - 1 if is_last or type_check is None: