Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 53 additions & 23 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ def {_field_name}(
type = original_type
any_of: list[TypeExpression] = []

typeddict_encoder = []
# Collect (type_check, encoder_expr) pairs for building ternary chain
encoder_parts: list[tuple[str | None, str]] = []
for i, t in enumerate(type.anyOf):
type_name, _, contents, _ = encode_type(
t,
Expand All @@ -403,34 +404,63 @@ def {_field_name}(
chunks.extend(contents)
if isinstance(t, RiverConcreteType):
if t.type == "string":
typeddict_encoder.extend(["x", " if isinstance(x, str) else "])
else:
# TODO(dstewart): This structure changed since we were incorrectly
# leaking ListTypeExprs into codegen. This generated
# code is probably wrong.
encoder_parts.append(("isinstance(x, str)", "x"))
elif t.type == "array":
match type_name:
case ListTypeExpr(inner_type_name):
typeddict_encoder.append(
f"encode_{render_literal_type(inner_type_name)}(x)"
# Primitives don't need encoding
inner_type_str = render_literal_type(inner_type_name)
if inner_type_str in ("str", "int", "float", "bool", "Any"):
encoder_parts.append(("isinstance(x, list)", "list(x)"))
else:
encoder_parts.append(
(
"isinstance(x, list)",
f"[encode_{inner_type_str}(y) for y in x]",
)
)
case _:
encoder_parts.append(("isinstance(x, list)", "list(x)"))
elif t.type == "object":
match type_name:
case TypeName(value):
encoder_parts.append(
("isinstance(x, dict)", f"encode_{value}(x)")
)
case _:
encoder_parts.append(("isinstance(x, dict)", "dict(x)"))
elif t.type in ("number", "integer"):
match type_name:
case LiteralTypeExpr(const):
typeddict_encoder.append(repr(const))
encoder_parts.append((f"x == {repr(const)}", repr(const)))
case _:
encoder_parts.append(("isinstance(x, (int, float))", "x"))
elif t.type == "boolean":
encoder_parts.append(("isinstance(x, bool)", "x"))
elif t.type == "null" or t.type == "undefined":
encoder_parts.append(("x is None", "None"))
else:
# Fallback for other types
match type_name:
case TypeName(value):
typeddict_encoder.append(f"encode_{value}(x)")
encoder_parts.append((None, f"encode_{value}(x)"))
case LiteralTypeExpr(const):
encoder_parts.append((None, repr(const)))
case NoneTypeExpr():
typeddict_encoder.append("None")
case other:
_o2: (
DictTypeExpr
| OpenUnionTypeExpr
| UnionTypeExpr
| LiteralType
) = other
raise ValueError(
f"What does it mean to have {
render_type_expr(_o2)
} here?"
)
encoder_parts.append((None, "None"))
case _:
encoder_parts.append((None, "x"))

# Build the ternary chain from encoder_parts
typeddict_encoder = list[str]()
for i, (type_check, encoder_expr) in enumerate(encoder_parts):
is_last = i == len(encoder_parts) - 1
if is_last or type_check is None:
# Last item or no type check - just the expression
typeddict_encoder.append(encoder_expr)
else:
# Add expression with type check
typeddict_encoder.append(f"{encoder_expr} if {type_check} else")
if permit_unknown_members:
union = _make_open_union_type_expr(any_of)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Code generated by river.codegen. DO NOT EDIT.
from pydantic import BaseModel
from typing import Literal

import replit_river as river


from .test_service import Test_ServiceService


class AnyOfMixedClient:
def __init__(self, client: river.Client[Literal[None]]):
self.test_service = Test_ServiceService(client)
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


from .anyof_mixed_method import (
Anyof_Mixed_MethodInput,
Anyof_Mixed_MethodOutput,
Anyof_Mixed_MethodOutputTypeAdapter,
encode_Anyof_Mixed_MethodInput,
encode_Anyof_Mixed_MethodInputNumber_Or_String,
encode_Anyof_Mixed_MethodInputRun_Command,
encode_Anyof_Mixed_MethodInputValue_Or_Null,
)


class Test_ServiceService:
def __init__(self, client: river.Client[Any]):
self.client = client

async def anyof_mixed_method(
self,
input: Anyof_Mixed_MethodInput,
timeout: datetime.timedelta,
) -> Anyof_Mixed_MethodOutput:
return await self.client.send_rpc(
"test_service",
"anyof_mixed_method",
input,
encode_Anyof_Mixed_MethodInput,
lambda x: Anyof_Mixed_MethodOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: RiverErrorTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
import datetime
from typing import (
Any,
Literal,
Mapping,
NotRequired,
TypedDict,
)
from typing_extensions import Annotated

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import (
RiverUnknownError,
translate_unknown_error,
RiverUnknownValue,
translate_unknown_value,
)

import replit_river as river


Anyof_Mixed_MethodInputNumber_Or_String = float | str


def encode_Anyof_Mixed_MethodInputNumber_Or_String(
x: "Anyof_Mixed_MethodInputNumber_Or_String",
) -> Any:
return x


def encode_Anyof_Mixed_MethodInputRun_CommandAnyOf_0(
x: "Anyof_Mixed_MethodInputRun_CommandAnyOf_0",
) -> Any:
return {
k: v
for (k, v) in (
{
"args": x.get("args"),
"env": x.get("env"),
}
).items()
if v is not None
}


class Anyof_Mixed_MethodInputRun_CommandAnyOf_0(TypedDict):
args: list[str]
env: NotRequired[dict[str, str] | None]


Anyof_Mixed_MethodInputRun_Command = (
Anyof_Mixed_MethodInputRun_CommandAnyOf_0 | str | list[str]
)


def encode_Anyof_Mixed_MethodInputRun_Command(
x: "Anyof_Mixed_MethodInputRun_Command",
) -> Any:
return (
encode_Anyof_Mixed_MethodInputRun_CommandAnyOf_0(x)
if isinstance(x, dict)
else x
if isinstance(x, str)
else list(x)
)


Anyof_Mixed_MethodInputValue_Or_Null = str | None


def encode_Anyof_Mixed_MethodInputValue_Or_Null(
x: "Anyof_Mixed_MethodInputValue_Or_Null",
) -> Any:
return x


def encode_Anyof_Mixed_MethodInput(
x: "Anyof_Mixed_MethodInput",
) -> Any:
return {
k: v
for (k, v) in (
{
"number_or_string": encode_Anyof_Mixed_MethodInputNumber_Or_String(
x["number_or_string"]
)
if "number_or_string" in x and x["number_or_string"]
else None,
"run_command": encode_Anyof_Mixed_MethodInputRun_Command(
x["run_command"]
),
"value_or_null": encode_Anyof_Mixed_MethodInputValue_Or_Null(
x["value_or_null"]
)
if "value_or_null" in x and x["value_or_null"]
else None,
}
).items()
if v is not None
}


class Anyof_Mixed_MethodInput(TypedDict):
number_or_string: NotRequired[Anyof_Mixed_MethodInputNumber_Or_String | None]
run_command: Anyof_Mixed_MethodInputRun_Command
value_or_null: NotRequired[Anyof_Mixed_MethodInputValue_Or_Null | None]


class Anyof_Mixed_MethodOutput(BaseModel):
success: bool


Anyof_Mixed_MethodOutputTypeAdapter: TypeAdapter[Anyof_Mixed_MethodOutput] = (
TypeAdapter(Anyof_Mixed_MethodOutput)
)
20 changes: 20 additions & 0 deletions tests/v1/codegen/snapshot/test_anyof_mixed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pytest_snapshot.plugin import Snapshot

from tests.fixtures.codegen_snapshot_fixtures import validate_codegen


async def test_anyof_mixed_types(snapshot: Snapshot) -> None:
"""Test codegen for anyOf unions with mixed types (object, string, array).

This tests the fix for the bug where non-discriminated anyOf unions
with mixed types like [object, string, array] would generate malformed
Python code with broken ternary expressions.
"""
validate_codegen(
snapshot=snapshot,
snapshot_dir="tests/v1/codegen/snapshot/snapshots",
read_schema=lambda: open("tests/v1/codegen/types/anyof_mixed_schema.json"),
target_path="test_anyof_mixed_types",
client_name="AnyOfMixedClient",
protocol_version="v1.1",
)
85 changes: 85 additions & 0 deletions tests/v1/codegen/types/anyof_mixed_schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
{
"services": {
"test_service": {
"procedures": {
"anyof_mixed_method": {
"input": {
"type": "object",
"properties": {
"run_command": {
"description": "Command can be object with args, string, or array of strings",
"anyOf": [
{
"type": "object",
"properties": {
"args": {
"type": "array",
"items": {
"type": "string"
}
},
"env": {
"type": "object",
"patternProperties": {
"^(.*)$": {
"type": "string"
}
}
}
},
"required": ["args"]
},
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "string"
}
}
]
},
"value_or_null": {
"description": "Value can be string or null",
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
]
},
"number_or_string": {
"description": "Can be number or string",
"anyOf": [
{
"type": "number"
},
{
"type": "string"
}
]
}
},
"required": ["run_command"]
},
"output": {
"type": "object",
"properties": {
"success": {
"type": "boolean"
}
},
"required": ["success"]
},
"errors": {
"not": {}
},
"type": "rpc"
}
}
}
}
}
Loading