From 88827b60015c3266541c3d7ccefbec9baf062201 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 15 May 2026 10:49:46 -0600 Subject: [PATCH] Add type-aware json serialization/object encoding --- CHANGELOG.md | 4 + .../decorators/durable_app.py | 20 +- .../models/DurableEntityContext.py | 39 +- .../models/DurableOrchestrationClient.py | 4 +- .../models/DurableOrchestrationContext.py | 79 ++- .../models/OrchestratorState.py | 4 +- azure/durable_functions/models/Task.py | 7 + .../models/TaskOrchestrationExecutor.py | 18 +- .../models/actions/CallActivityAction.py | 5 +- .../actions/CallActivityWithRetryAction.py | 5 +- .../models/actions/CallEntityAction.py | 5 +- .../actions/CallSubOrchestratorAction.py | 5 +- .../CallSubOrchestratorWithRetryAction.py | 5 +- .../models/actions/ContinueAsNewAction.py | 5 +- .../models/actions/SignalEntityAction.py | 5 +- .../models/entities/EntityState.py | 4 +- .../models/entities/OperationResult.py | 5 +- .../models/utils/df_serialization.py | 226 ++++++ .../models/utils/type_discovery.py | 83 +++ azure/durable_functions/orchestrator.py | 7 +- tests/models/test_Decorators.py | 16 + .../test_DurableOrchestrationContext.py | 71 ++ tests/orchestrator/test_expected_type.py | 164 +++++ tests/orchestrator/test_external_event.py | 36 +- tests/utils/test_df_serialization.py | 656 ++++++++++++++++++ tests/utils/test_type_discovery.py | 81 +++ 26 files changed, 1497 insertions(+), 62 deletions(-) create mode 100644 azure/durable_functions/models/utils/df_serialization.py create mode 100644 azure/durable_functions/models/utils/type_discovery.py create mode 100644 tests/orchestrator/test_expected_type.py create mode 100644 tests/utils/test_df_serialization.py create mode 100644 tests/utils/test_type_discovery.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3216b2c8..d9af3788 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ All notable changes to this project will be documented in this file. ### Added - Client operation correlation logging: `FunctionInvocationId` is now propagated via HTTP headers to the host for client operations, enabling correlation with host logs. +- Centralized JSON serialization module (`azure.durable_functions.models.utils.df_serialization`): all serialization/deserialization of user payloads (orchestrator inputs/outputs, activity arguments and results, sub-orchestrator payloads, entity inputs/outputs, and client inputs) now flows through `df_dumps` / `df_loads`, replacing scattered `json.dumps(…, default=_serialize_custom_object)` / `json.loads(…, object_hook=_deserialize_custom_object)` calls. The wire format is **unchanged** — builtins serialize to plain JSON and custom objects continue to use the `{"__class__", "__module__", "__data__"}` convention. +- Type-hint-driven validation via `df_loads(s, expected_type=...)`: when the V2 programming model provides a return-type annotation for an activity or sub-orchestrator, `df_loads` validates the deserialized payload against that type **before** the legacy `object_hook` fires, catching class/module mismatches early. +- **Strict typing mode** (opt-in via `AZURE_FUNCTIONS_DURABLE_STRICT_TYPING=1`): when enabled, `import_module` is never called on either encode or decode. On encode, `df_dumps` wraps only the top-level custom object — `to_json()` must return plain-JSON-serializable data (nested custom objects must be serialized explicitly). 
On decode, `df_loads` calls `expected_type.from_json(raw["__data__"])` directly; `df_loads` without `expected_type` raises `TypeError` for custom-object payloads. A `TypeError` is also raised on type mismatch. +- Return-type discovery for V2 decorated activities/sub-orchestrators (`azure.durable_functions.models.utils.type_discovery`): resolves the concrete return annotation from the user's registered function, used to supply `expected_type` to `df_loads`. ## 1.0.0b6 diff --git a/azure/durable_functions/decorators/durable_app.py b/azure/durable_functions/decorators/durable_app.py index 2b74a045..b885e4a5 100644 --- a/azure/durable_functions/decorators/durable_app.py +++ b/azure/durable_functions/decorators/durable_app.py @@ -76,7 +76,7 @@ def decorator(entity_func): return decorator - def _configure_orchestrator_callable(self, wrap) -> Callable: + def _configure_orchestrator_callable(self, wrap, input_type=None) -> Callable: """Obtain decorator to construct an Orchestrator class from a user-defined Function. In the old programming model, this decorator's logic was unavoidable boilerplate @@ -86,6 +86,9 @@ def _configure_orchestrator_callable(self, wrap) -> Callable: ---------- wrap: Callable The next decorator to be applied. + input_type: Optional[type] + The expected type for orchestration input, forwarded from + the orchestration_trigger decorator. Returns ------- @@ -99,12 +102,16 @@ def decorator(orchestrator_func): # invoke next decorator, with the Orchestrator as input handle.__name__ = orchestrator_func.__name__ + # Stash the decorator-declared input type so the runtime + # can feed it to df_loads via context.get_input(). + handle._df_input_type = input_type return wrap(handle) return decorator def orchestration_trigger(self, context_name: str, - orchestration: Optional[str] = None): + orchestration: Optional[str] = None, + input_type: Optional[type] = None): """Register an Orchestrator Function. Parameters @@ -114,8 +121,13 @@ def orchestration_trigger(self, context_name: str, orchestration: Optional[str] Name of Orchestrator Function. The value is None by default, in which case the name of the method is used. + input_type: Optional[type] + The expected type for the orchestration input. When set, + ``context.get_input()`` will use this type to decode the + input payload without consulting ``sys.modules``. A + call-site ``expected_type`` argument on ``get_input`` + takes precedence over this value. 
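+
+        Example
+        -------
+        A minimal sketch; ``Order`` is an illustrative user class
+        exposing ``to_json`` / ``from_json``::
+
+            @app.orchestration_trigger(context_name="context",
+                                       input_type=Order)
+            def place_order(context):
+                order = context.get_input()  # decoded as Order
+                result = yield context.call_activity("process_order", order)
+                return result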
""" - @self._configure_orchestrator_callable @self._configure_function_builder def wrap(fb): @@ -127,7 +139,7 @@ def decorator(): return decorator() - return wrap + return self._configure_orchestrator_callable(wrap, input_type=input_type) def activity_trigger(self, input_name: str, activity: Optional[str] = None): diff --git a/azure/durable_functions/models/DurableEntityContext.py b/azure/durable_functions/models/DurableEntityContext.py index 37cc980c..43c8f322 100644 --- a/azure/durable_functions/models/DurableEntityContext.py +++ b/azure/durable_functions/models/DurableEntityContext.py @@ -1,5 +1,5 @@ from typing import Optional, Any, Dict, Tuple, List, Callable -from azure.functions._durable_functions import _deserialize_custom_object +from .utils.df_serialization import df_loads import json @@ -36,6 +36,7 @@ def __init__(self, self._is_newly_constructed: bool = False self._state: Any = state + self._state_is_raw: bool = False self._input: Any = None self._operation: Optional[str] = None self._result: Any = None @@ -109,10 +110,17 @@ def from_json(cls, json_str: str) -> Tuple['DurableEntityContext', List[Dict[str serialized_state = json_dict["state"] if serialized_state is not None: - json_dict["state"] = from_json_util(serialized_state) + # Keep the raw serialized form so get_state() can deserialize + # lazily with an expected_type supplied by the user. + json_dict["state"] = serialized_state + else: + json_dict["state"] = None batch = json_dict.pop("batch") - return cls(**json_dict), batch + ctx = cls(**json_dict) + if serialized_state is not None: + ctx._state_is_raw = True + return ctx, batch def set_state(self, state: Any) -> None: """Set the state of the entity. @@ -127,19 +135,26 @@ def set_state(self, state: Any) -> None: # should only serialize the state at the end of the batch self._state = state - def get_state(self, initializer: Optional[Callable[[], Any]] = None) -> Any: + def get_state(self, initializer: Optional[Callable[[], Any]] = None, + expected_type: Optional[type] = None) -> Any: """Get the current state of this entity. Parameters ---------- initializer: Optional[Callable[[], Any]] A 0-argument function to provide an initial state. Defaults to None. + expected_type: Optional[type] + The type to decode the state as. When set, the codec uses + this type directly without consulting ``sys.modules``. Returns ------- Any The current state of the entity """ + if self._state is not None and self._state_is_raw: + self._state = from_json_util(self._state, expected_type=expected_type) + self._state_is_raw = False state = self._state if state is not None: return state @@ -149,9 +164,15 @@ def get_state(self, initializer: Optional[Callable[[], Any]] = None) -> Any: state = initializer() return state - def get_input(self) -> Any: + def get_input(self, expected_type: Optional[type] = None) -> Any: """Get the input for this operation. + Parameters + ---------- + expected_type: Optional[type] + The type to decode the input as. When set, the codec uses + this type directly without consulting ``sys.modules``. 
+ Returns ------- Any @@ -160,7 +181,7 @@ def get_input(self) -> Any: input_ = None req_input = self._input req_input = json.loads(req_input) - input_ = None if req_input is None else from_json_util(req_input) + input_ = None if req_input is None else df_loads(req_input, expected_type=expected_type) return input_ def set_result(self, result: Any) -> None: @@ -180,7 +201,7 @@ def destruct_on_exit(self) -> None: self._state = None -def from_json_util(json_str: str) -> Any: +def from_json_util(json_str: str, expected_type: Optional[type] = None) -> Any: """Load an arbitrary datatype from its JSON representation. The Out-of-proc SDK has a special JSON encoding strategy @@ -192,10 +213,12 @@ def from_json_util(json_str: str) -> Any: ---------- json_str: str A JSON-formatted string, from durable-extension + expected_type: Optional[type] + The type to decode the value as. Returns ------- Any: The original datatype that was serialized """ - return json.loads(json_str, object_hook=_deserialize_custom_object) + return df_loads(json_str, expected_type=expected_type) diff --git a/azure/durable_functions/models/DurableOrchestrationClient.py b/azure/durable_functions/models/DurableOrchestrationClient.py index 009001e5..b6acc389 100644 --- a/azure/durable_functions/models/DurableOrchestrationClient.py +++ b/azure/durable_functions/models/DurableOrchestrationClient.py @@ -16,7 +16,7 @@ from ..models.DurableOrchestrationBindings import DurableOrchestrationBindings from .utils.http_utils import get_async_request, post_async_request, delete_async_request from .utils.entity_utils import EntityId -from azure.functions._durable_functions import _serialize_custom_object +from .utils.df_serialization import df_dumps class DurableOrchestrationClient: @@ -633,7 +633,7 @@ def _get_json_input(client_input: object) -> Optional[str]: If the JSON serialization failed, see `serialize_custom_object` """ if client_input is not None: - return json.dumps(client_input, default=_serialize_custom_object) + return df_dumps(client_input) return None @staticmethod diff --git a/azure/durable_functions/models/DurableOrchestrationContext.py b/azure/durable_functions/models/DurableOrchestrationContext.py index 531307c3..d2803360 100644 --- a/azure/durable_functions/models/DurableOrchestrationContext.py +++ b/azure/durable_functions/models/DurableOrchestrationContext.py @@ -34,7 +34,11 @@ from .actions import Action from ..models.TokenSource import TokenSource from .utils.entity_utils import EntityId -from azure.functions._durable_functions import _deserialize_custom_object +from .utils.df_serialization import df_loads +from .utils.type_discovery import ( + activity_output_type, + sub_orchestrator_output_type, +) from azure.durable_functions.constants import DATETIME_STRING_FORMAT from azure.durable_functions.decorators.metadata import OrchestrationTrigger, ActivityTrigger from azure.functions.decorators.function_app import FunctionBuilder @@ -167,7 +171,8 @@ def _set_is_replaying(self, is_replaying: bool): """ self._is_replaying = is_replaying - def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None) -> TaskBase: + def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule an activity for execution. Parameters @@ -177,6 +182,10 @@ def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None in the Python V2 programming model, the activity function itself. 
input_: Optional[Any] The JSON-serializable input to pass to the activity function. + expected_type: Optional[type] + The type to decode the activity result as. Takes precedence + over the type discovered from the activity's return + annotation. Returns ------- @@ -191,16 +200,21 @@ def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None "decorator. Otherwise, provide in the name of the activity as a string." raise ValueError(error_message) + # Discover the activity's return type from its annotation, if any, + # so the result can be decoded without consulting sys.modules. + resolved_type = expected_type or activity_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, ActivityTrigger) action = CallActivityAction(name, input_) task = self._generate_task(action) + task._expected_output_type = resolved_type return task def call_activity_with_retry(self, name: Union[str, Callable], retry_options: RetryOptions, - input_: Optional[Any] = None) -> TaskBase: + input_: Optional[Any] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule an activity for execution with retry options. Parameters @@ -212,6 +226,10 @@ def call_activity_with_retry(self, The retry options for the activity function. input_: Optional[Any] The JSON-serializable input to pass to the activity function. + expected_type: Optional[type] + The type to decode the activity result as. Takes precedence + over the type discovered from the activity's return + annotation. Returns ------- @@ -227,11 +245,13 @@ def call_activity_with_retry(self, "decorator. Otherwise, provide in the name of the activity as a string." raise ValueError(error_message) + resolved_type = expected_type or activity_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, ActivityTrigger) action = CallActivityWithRetryAction(name, retry_options, input_) task = self._generate_task(action, retry_options) + task._expected_output_type = resolved_type return task def call_http(self, method: str, uri: str, content: Optional[str] = None, @@ -288,7 +308,8 @@ def call_http(self, method: str, uri: str, content: Optional[str] = None, def call_sub_orchestrator(self, name: Union[str, Callable], input_: Optional[Any] = None, instance_id: Optional[str] = None, - version: Optional[str] = None) -> TaskBase: + version: Optional[str] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule sub-orchestration function named `name` for execution. Parameters @@ -302,6 +323,10 @@ def call_sub_orchestrator(self, version: Optional[str] The version to assign to the sub-orchestration instance. If not specified, the defaultVersion from host.json will be used. + expected_type: Optional[type] + The type to decode the sub-orchestrator result as. Takes + precedence over the type discovered from the + sub-orchestrator's return annotation. Returns ------- @@ -316,18 +341,21 @@ def call_sub_orchestrator(self, "decorator. Otherwise, provide in the name of the activity as a string." 
raise ValueError(error_message) + resolved_type = expected_type or sub_orchestrator_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, OrchestrationTrigger) action = CallSubOrchestratorAction(name, input_, instance_id, version) task = self._generate_task(action) + task._expected_output_type = resolved_type return task def call_sub_orchestrator_with_retry(self, name: Union[str, Callable], retry_options: RetryOptions, input_: Optional[Any] = None, instance_id: Optional[str] = None, - version: Optional[str] = None) -> TaskBase: + version: Optional[str] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule sub-orchestration function named `name` for execution, with retry-options. Parameters @@ -343,6 +371,10 @@ def call_sub_orchestrator_with_retry(self, version: Optional[str] The version to assign to the sub-orchestration instance. If not specified, the defaultVersion from host.json will be used. + expected_type: Optional[type] + The type to decode the sub-orchestrator result as. Takes + precedence over the type discovered from the + sub-orchestrator's return annotation. Returns ------- @@ -357,18 +389,31 @@ def call_sub_orchestrator_with_retry(self, "decorator. Otherwise, provide in the name of the activity as a string." raise ValueError(error_message) + resolved_type = expected_type or sub_orchestrator_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, OrchestrationTrigger) action = CallSubOrchestratorWithRetryAction( name, retry_options, input_, instance_id, version) task = self._generate_task(action, retry_options) + task._expected_output_type = resolved_type return task - def get_input(self) -> Optional[Any]: - """Get the orchestration input.""" - return None if self._input is None else json.loads(self._input, - object_hook=_deserialize_custom_object) + def get_input(self, expected_type: Optional[type] = None) -> Optional[Any]: + """Get the orchestration input. + + Parameters + ---------- + expected_type : Optional[type] + The type to decode the input as. Takes precedence over + the ``input_type`` declared on the orchestration trigger + decorator. When neither is set, decoding falls back to + module-only class resolution. + """ + if self._input is None: + return None + resolved = expected_type or getattr(self, "_input_expected_type", None) + return df_loads(self._input, expected_type=resolved) def new_uuid(self) -> str: """Create a new UUID that is safe for replay within an orchestration or operation. @@ -535,7 +580,8 @@ def function_context(self) -> FunctionContext: return self._function_context def call_entity(self, entityId: EntityId, - operationName: str, operationInput: Optional[Any] = None): + operationName: str, operationInput: Optional[Any] = None, + expected_type: Optional[type] = None): """Get the result of Durable Entity operation given some input. Parameters @@ -546,6 +592,10 @@ def call_entity(self, entityId: EntityId, The operation to execute operationInput: Optional[Any] The input for tne operation, defaults to None. + expected_type: Optional[type] + The type to decode the entity response as. When set, the + codec uses this type directly without consulting + ``sys.modules``. 
Returns ------- @@ -554,6 +604,7 @@ def call_entity(self, entityId: EntityId, """ action = CallEntityAction(entityId, operationName, operationInput) task = self._generate_task(action) + task._expected_output_type = expected_type return task def _record_fire_and_forget_action(self, action: Action): @@ -627,13 +678,18 @@ def create_timer(self, fire_at: datetime.datetime) -> TaskBase: task = self._generate_task(action, task_constructor=TimerTask) return task - def wait_for_external_event(self, name: str) -> TaskBase: + def wait_for_external_event(self, name: str, + expected_type: Optional[type] = None) -> TaskBase: """Wait asynchronously for an event to be raised with the name `name`. Parameters ---------- name : str The event name of the event that the task is waiting for. + expected_type : Optional[type] + The type to decode the event payload as. When set, the + codec uses this type directly without consulting + ``sys.modules``. Returns ------- @@ -642,6 +698,7 @@ def wait_for_external_event(self, name: str) -> TaskBase: """ action = WaitForExternalEventAction(name) task = self._generate_task(action, id_=name) + task._expected_output_type = expected_type return task def continue_as_new(self, input_: Any): diff --git a/azure/durable_functions/models/OrchestratorState.py b/azure/durable_functions/models/OrchestratorState.py index 7b426292..36fa2b28 100644 --- a/azure/durable_functions/models/OrchestratorState.py +++ b/azure/durable_functions/models/OrchestratorState.py @@ -4,8 +4,8 @@ from azure.durable_functions.models.ReplaySchema import ReplaySchema from .utils.json_utils import add_attrib +from .utils.df_serialization import _get_serialize_default from azure.durable_functions.models.actions.Action import Action -from azure.functions._durable_functions import _serialize_custom_object class OrchestratorState: @@ -114,4 +114,4 @@ def to_json_string(self) -> str: The instance of the object in json string format """ json_dict = self.to_json() - return json.dumps(json_dict, default=_serialize_custom_object) + return json.dumps(json_dict, default=_get_serialize_default()) diff --git a/azure/durable_functions/models/Task.py b/azure/durable_functions/models/Task.py index 7aa5b256..e5667001 100644 --- a/azure/durable_functions/models/Task.py +++ b/azure/durable_functions/models/Task.py @@ -58,6 +58,13 @@ def __init__(self, id_: Union[int, str], actions: Union[List[Action], Action]): self.action_repr: Union[List[Action], Action] = actions self.is_played = False self._is_scheduled_flag = False + # The expected return type discovered from the user function's + # annotation, when the task was scheduled with a V2 FunctionBuilder. + # Forwarded to ``df_loads`` so custom objects can be decoded without + # touching ``sys.modules``/``importlib``. ``None`` means "no type + # info available" -- the codec then falls back to module lookup + # and, ultimately, the legacy decoder with a warning. 
+ self._expected_output_type: Optional[type] = None @property def _is_scheduled(self) -> bool: diff --git a/azure/durable_functions/models/TaskOrchestrationExecutor.py b/azure/durable_functions/models/TaskOrchestrationExecutor.py index efe7adbb..73fd63f4 100644 --- a/azure/durable_functions/models/TaskOrchestrationExecutor.py +++ b/azure/durable_functions/models/TaskOrchestrationExecutor.py @@ -9,7 +9,7 @@ from collections import namedtuple import json from ..models.entities.ResponseMessage import ResponseMessage -from azure.functions._durable_functions import _deserialize_custom_object +from .utils.df_serialization import df_loads class TaskOrchestrationExecutor: @@ -181,18 +181,21 @@ def parse_history_event(directive_result): raise ValueError("EventType is not found in task object") # We provide the ability to deserialize custom objects, because the output of this - # will be passed directly to the orchestrator as the output of some activity + # will be passed directly to the orchestrator as the output of some activity. + # The expected type (when discoverable from the activity / sub-orchestrator's + # return annotation) lets ``df_loads`` decode custom classes without consulting + # ``sys.modules`` / ``importlib``. + expected_type = getattr(task, "_expected_output_type", None) if (event_type == HistoryEventType.SUB_ORCHESTRATION_INSTANCE_COMPLETED and directive_result.Result is not None): - return json.loads(directive_result.Result, object_hook=_deserialize_custom_object) + return df_loads(directive_result.Result, expected_type=expected_type) if (event_type == HistoryEventType.TASK_COMPLETED and directive_result.Result is not None): - return json.loads(directive_result.Result, object_hook=_deserialize_custom_object) + return df_loads(directive_result.Result, expected_type=expected_type) if (event_type == HistoryEventType.EVENT_RAISED and directive_result.Input is not None): # TODO: Investigate why the payload is in "Input" instead of "Result" - response = json.loads(directive_result.Input, - object_hook=_deserialize_custom_object) + response = df_loads(directive_result.Input, expected_type=expected_type) return response return None @@ -217,7 +220,8 @@ def parse_history_event(directive_result): new_value = parse_history_event(event) if task._api_name == "CallEntityAction": event_payload = ResponseMessage.from_dict(new_value) - new_value = json.loads(event_payload.result) + entity_expected = getattr(task, "_expected_output_type", None) + new_value = df_loads(event_payload.result, expected_type=entity_expected) if event_payload.is_exception: new_value = Exception(new_value) diff --git a/azure/durable_functions/models/actions/CallActivityAction.py b/azure/durable_functions/models/actions/CallActivityAction.py index 2e5c4ade..ea3fe7c4 100644 --- a/azure/durable_functions/models/actions/CallActivityAction.py +++ b/azure/durable_functions/models/actions/CallActivityAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallActivityAction(Action): @@ -16,7 +15,7 @@ class CallActivityAction(Action): def __init__(self, function_name: str, input_=None): self.function_name: str = function_name # It appears that `.input_` needs to be JSON-serializable at this point - self.input_ = dumps(input_, default=_serialize_custom_object) + self.input_ = df_dumps(input_) if not 
self.function_name: raise ValueError("function_name cannot be empty") diff --git a/azure/durable_functions/models/actions/CallActivityWithRetryAction.py b/azure/durable_functions/models/actions/CallActivityWithRetryAction.py index a6b33288..e21cda55 100644 --- a/azure/durable_functions/models/actions/CallActivityWithRetryAction.py +++ b/azure/durable_functions/models/actions/CallActivityWithRetryAction.py @@ -1,11 +1,10 @@ -from json import dumps from typing import Dict, Union from .Action import Action from .ActionType import ActionType from ..RetryOptions import RetryOptions from ..utils.json_utils import add_attrib, add_json_attrib -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallActivityWithRetryAction(Action): @@ -18,7 +17,7 @@ def __init__(self, function_name: str, retry_options: RetryOptions, input_=None): self.function_name: str = function_name self.retry_options: RetryOptions = retry_options - self.input_ = dumps(input_, default=_serialize_custom_object) + self.input_ = df_dumps(input_) if not self.function_name: raise ValueError("function_name cannot be empty") diff --git a/azure/durable_functions/models/actions/CallEntityAction.py b/azure/durable_functions/models/actions/CallEntityAction.py index 55baa4ef..894914a5 100644 --- a/azure/durable_functions/models/actions/CallEntityAction.py +++ b/azure/durable_functions/models/actions/CallEntityAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps from ..utils.entity_utils import EntityId @@ -23,7 +22,7 @@ def __init__(self, entity_id: EntityId, operation: str, input_=None): self.instance_id: str = EntityId.get_scheduler_id(entity_id) self.operation: str = operation - self.input_: str = dumps(input_, default=_serialize_custom_object) + self.input_: str = df_dumps(input_) @property def action_type(self) -> int: diff --git a/azure/durable_functions/models/actions/CallSubOrchestratorAction.py b/azure/durable_functions/models/actions/CallSubOrchestratorAction.py index 03a22413..2925e459 100644 --- a/azure/durable_functions/models/actions/CallSubOrchestratorAction.py +++ b/azure/durable_functions/models/actions/CallSubOrchestratorAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallSubOrchestratorAction(Action): @@ -13,7 +12,7 @@ class CallSubOrchestratorAction(Action): def __init__(self, function_name: str, _input: Optional[Any] = None, instance_id: Optional[str] = None, version: Optional[str] = None): self.function_name: str = function_name - self._input: str = dumps(_input, default=_serialize_custom_object) + self._input: str = df_dumps(_input) self.instance_id: Optional[str] = instance_id self.version: Optional[str] = version diff --git a/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py b/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py index c72d7181..61c5bb73 100644 --- a/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py +++ b/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py @@ -3,9 +3,8 @@ from 
.Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib, add_json_attrib -from json import dumps from ..RetryOptions import RetryOptions -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallSubOrchestratorWithRetryAction(Action): @@ -15,7 +14,7 @@ def __init__(self, function_name: str, retry_options: RetryOptions, _input: Optional[Any] = None, instance_id: Optional[str] = None, version: Optional[str] = None): self.function_name: str = function_name - self._input: str = dumps(_input, default=_serialize_custom_object) + self._input: str = df_dumps(_input) self.retry_options: RetryOptions = retry_options self.instance_id: Optional[str] = instance_id self.version: Optional[str] = version diff --git a/azure/durable_functions/models/actions/ContinueAsNewAction.py b/azure/durable_functions/models/actions/ContinueAsNewAction.py index 7af0508b..4573566c 100644 --- a/azure/durable_functions/models/actions/ContinueAsNewAction.py +++ b/azure/durable_functions/models/actions/ContinueAsNewAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class ContinueAsNewAction(Action): @@ -15,7 +14,7 @@ class ContinueAsNewAction(Action): """ def __init__(self, input_=None): - self.input_ = dumps(input_, default=_serialize_custom_object) + self.input_ = df_dumps(input_) @property def action_type(self) -> int: diff --git a/azure/durable_functions/models/actions/SignalEntityAction.py b/azure/durable_functions/models/actions/SignalEntityAction.py index d6e9be54..d7ace9a5 100644 --- a/azure/durable_functions/models/actions/SignalEntityAction.py +++ b/azure/durable_functions/models/actions/SignalEntityAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps from ..utils.entity_utils import EntityId @@ -23,7 +22,7 @@ def __init__(self, entity_id: EntityId, operation: str, input_=None): self.instance_id: str = EntityId.get_scheduler_id(entity_id) self.operation: str = operation - self.input_: str = dumps(input_, default=_serialize_custom_object) + self.input_: str = df_dumps(input_) @property def action_type(self) -> int: diff --git a/azure/durable_functions/models/entities/EntityState.py b/azure/durable_functions/models/entities/EntityState.py index 13d22e7e..1fabf6d7 100644 --- a/azure/durable_functions/models/entities/EntityState.py +++ b/azure/durable_functions/models/entities/EntityState.py @@ -1,6 +1,6 @@ from typing import List, Optional, Dict, Any from .Signal import Signal -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps from .OperationResult import OperationResult import json @@ -56,7 +56,7 @@ def to_json(self) -> Dict[str, Any]: serialized_results = list(map(lambda x: x.to_json(), self.results)) json_dict["entityExists"] = self.entity_exists - json_dict["entityState"] = json.dumps(self.state, default=_serialize_custom_object) + json_dict["entityState"] = df_dumps(self.state) json_dict["results"] = serialized_results json_dict["signals"] = self.signals return json_dict diff --git 
a/azure/durable_functions/models/entities/OperationResult.py b/azure/durable_functions/models/entities/OperationResult.py index 05147f09..744dd285 100644 --- a/azure/durable_functions/models/entities/OperationResult.py +++ b/azure/durable_functions/models/entities/OperationResult.py @@ -1,6 +1,5 @@ from typing import Optional, Dict, Any -from azure.functions._durable_functions import _serialize_custom_object -import json +from ..utils.df_serialization import df_dumps class OperationResult: @@ -90,5 +89,5 @@ def to_json(self) -> Dict[str, Any]: to_json["isError"] = self.is_error to_json["duration"] = self.duration to_json["startTime"] = self.execution_start_time_ms - to_json["result"] = json.dumps(self.result, default=_serialize_custom_object) + to_json["result"] = df_dumps(self.result) return to_json diff --git a/azure/durable_functions/models/utils/df_serialization.py b/azure/durable_functions/models/utils/df_serialization.py new file mode 100644 index 00000000..31bae9b6 --- /dev/null +++ b/azure/durable_functions/models/utils/df_serialization.py @@ -0,0 +1,226 @@ +"""Centralized JSON serialization for Durable Functions payloads. + +This module wraps the legacy `json.dumps(value, default=_serialize_custom_object)` +/ `json.loads(s, object_hook=_deserialize_custom_object)` pipeline from +`azure.functions._durable_functions` behind `df_dumps` and `df_loads`. + +The wire format is **unchanged** -- builtins serialize to plain JSON and custom +objects use the `{"__class__": ..., "__module__": ..., "__data__": ...}` +convention that the Durable extension and downstream consumers already expect. + +`df_loads` adds an optional `expected_type` parameter that controls +type validation. Behavior depends on the typing mode: + +* **Loose mode** (default) -- the payload is inspected before + deserialization and a warning is logged on type mismatch, then the + legacy ``object_hook`` pipeline runs as usual. +* **Strict mode** -- ``import_module`` is never called on either side. + On encode, ``to_json`` is called on the top-level object only and + the result must be plain-JSON-serializable (nested custom objects + are **not** auto-encoded -- ``to_json`` must handle them). On + decode, ``expected_type.from_json`` is invoked directly with plain + JSON data. A ``TypeError`` is raised on type mismatch or if + ``expected_type`` is not provided for a custom-object payload. + Opt in by setting ``AZURE_FUNCTIONS_DURABLE_STRICT_TYPING`` to a + truthy value (``1``, ``true``, ``yes``).""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any, Optional + +from azure.functions._durable_functions import ( + _deserialize_custom_object, + _serialize_custom_object, +) + +logger = logging.getLogger(__name__) + +_STRICT_ENV_VAR = "AZURE_FUNCTIONS_DURABLE_STRICT_TYPING" +_TRUTHY = frozenset({"1", "true", "yes"}) + + +def _is_strict_mode() -> bool: + return os.environ.get(_STRICT_ENV_VAR, "").strip().lower() in _TRUTHY + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def df_dumps(value: Any) -> str: + """Serialize *value* to a JSON string. + + In **loose mode** (default), custom objects are encoded recursively + via the legacy ``default=_serialize_custom_object`` handler — any + nested custom object is automatically wrapped in the + ``{"__class__", "__module__", "__data__"}`` envelope. 
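+
+    For example, with an illustrative user class ``Order`` exposing
+    ``to_json`` / ``from_json``, loose-mode encoding produces roughly::
+
+        df_dumps(Order("widget", 5))
+        # -> '{"__class__": "Order", "__module__": "orders",
+        #      "__data__": {"item": "widget", "qty": 5}}'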
+
+    In **strict mode**, the top-level custom object (if it has
+    ``to_json``) is wrapped in the legacy envelope, but the
+    ``__data__`` payload is serialized as **plain JSON** — no
+    ``default=`` hook fires. This means ``to_json()`` must return a
+    value that is natively JSON-serializable (dicts, lists, strings,
+    numbers, bools, None). A ``TypeError`` is raised at encode time
+    if any nested value is not serializable.
+    """
+    if _is_strict_mode():
+        if hasattr(value, "to_json"):
+            envelope = _serialize_custom_object(value)
+            return json.dumps(envelope)
+        # Primitive / plain-JSON value — serialize without default=
+        # so stray custom objects are caught immediately.
+        return json.dumps(value)
+    return json.dumps(value, default=_serialize_custom_object)
+
+
+def df_loads(s: str, expected_type: Optional[type] = None) -> Any:
+    """Deserialize a JSON string, optionally validating the result type.
+
+    Parameters
+    ----------
+    s : str
+        The JSON-encoded payload.
+    expected_type : type, optional
+        When provided the raw JSON is parsed first (without triggering
+        ``import_module`` via the legacy ``object_hook``). If the
+        payload is a legacy custom-object dict its embedded class info
+        is validated against *expected_type* **before** any module is
+        imported. In strict mode a matching *expected_type* is used to
+        call ``from_json`` directly, avoiding ``import_module``
+        entirely; in loose mode decoding still falls back to the
+        legacy ``object_hook`` path after this validation.
+        In loose mode a warning is emitted on mismatch; in strict mode
+        a ``TypeError`` is raised.
+    """
+    if expected_type is not None:
+        return _loads_with_expected_type(s, expected_type)
+
+    if _is_strict_mode():
+        return _loads_strict_no_type(s)
+
+    return json.loads(s, object_hook=_deserialize_custom_object)
+
+
+def _loads_strict_no_type(s: str) -> Any:
+    """Strict-mode fallback when no *expected_type* is available.
+
+    Parses without ``object_hook`` so ``import_module`` is never called.
+    If the top-level value is a legacy custom-object dict, raises
+    ``TypeError`` — the caller must supply an ``expected_type`` to
+    deserialize custom objects in strict mode.
+    """
+    raw = json.loads(s)
+    if _is_legacy_custom_dict(raw):
+        raise TypeError(
+            "df_loads: strict mode requires expected_type to "
+            "deserialize custom-object payloads, but none was provided. "
+            f"Payload declares {raw['__module__']}.{raw['__class__']}."
+        )
+    return raw
+
+
+def _get_serialize_default():
+    """Return the `default` callback for `json.dumps`.
+
+    Use this in places that build their own `json.dumps` call (e.g.
+    `OrchestratorState.to_json_string`) rather than going through
+    `df_dumps`.
+
+    In strict mode returns ``None`` — `OrchestratorState` fields are
+    already serialized via `df_dumps` so there should be no remaining
+    custom objects to encode. A stray custom object will raise
+    ``TypeError`` from ``json.dumps``, surfacing the problem early.
+    """
+    if _is_strict_mode():
+        return None
+    return _serialize_custom_object
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+_LEGACY_KEYS = frozenset({"__class__", "__module__", "__data__"})
+
+
+def _is_legacy_custom_dict(d: Any) -> bool:
+    """Return True if *d* is a dict with legacy custom-object markers."""
+    return isinstance(d, dict) and _LEGACY_KEYS.issubset(d)
+
+
+def _loads_with_expected_type(s: str, expected_type: type) -> Any:
+    """Parse *s* and validate against *expected_type*.
+ + The raw JSON is parsed **without** the legacy ``object_hook`` so we + can inspect the payload before ``import_module`` fires. + + * **Strict mode** -- for custom-object payloads, calls + ``expected_type.from_json`` directly (no ``import_module``). For + primitives, validates then returns the plain value. Raises + ``TypeError`` on mismatch. + * **Loose mode** -- logs a warning on mismatch, then falls through + to the normal ``json.loads(s, object_hook=...)`` legacy path. + """ + raw = json.loads(s) + strict = _is_strict_mode() + + if _is_legacy_custom_dict(raw): + class_name = raw["__class__"] + module_name = raw["__module__"] + type_matches = (class_name == expected_type.__name__ + and module_name == expected_type.__module__) + + if not type_matches: + msg = ( + f"df_loads: payload declares class " + f"{module_name}.{class_name} but expected " + f"{expected_type.__module__}.{expected_type.__name__}" + ) + if strict: + raise TypeError(msg) + logger.warning(msg) + + if strict: + # Bypass import_module entirely — call from_json directly. + if not _has_json_protocol(expected_type): + raise TypeError( + f"df_loads: expected_type " + f"{expected_type.__module__}.{expected_type.__name__} " + f"does not expose from_json" + ) + return expected_type.from_json(raw["__data__"]) + + # Loose mode — legacy deserialization. + return json.loads(s, object_hook=_deserialize_custom_object) + + # Primitive / plain-JSON payload — validate the Python type. + if not _is_compatible(raw, expected_type): + msg = ( + f"df_loads: deserialized value ({type(raw).__name__}) is not " + f"compatible with expected type {expected_type}" + ) + if strict: + raise TypeError(msg) + logger.warning(msg) + + if strict: + return raw + # Loose mode — use legacy deserializer so nested custom objects + # (inside dicts/lists) are still reconstructed via object_hook. + return json.loads(s, object_hook=_deserialize_custom_object) + +def _has_json_protocol(cls: type) -> bool: + """Return True iff *cls* exposes callable `to_json` and `from_json`.""" + return callable(getattr(cls, "to_json", None)) and callable( + getattr(cls, "from_json", None) + ) + + +def _is_compatible(value: Any, expected_type: type) -> bool: + """Best-effort `isinstance` check that tolerates generic type hints.""" + try: + return isinstance(value, expected_type) + except TypeError: + # typing constructs like `List[int]` aren't valid for isinstance. + return True diff --git a/azure/durable_functions/models/utils/type_discovery.py b/azure/durable_functions/models/utils/type_discovery.py new file mode 100644 index 00000000..64da16cc --- /dev/null +++ b/azure/durable_functions/models/utils/type_discovery.py @@ -0,0 +1,83 @@ +"""Best-effort type-hint discovery for Durable Functions call sites. + +These helpers feed the ``expected_type`` argument of +``df_serialization.df_loads`` so that custom-class instances can be +re-instantiated without consulting ``sys.modules`` / ``importlib``. + +All public helpers swallow exceptions and return ``None`` on failure -- +the caller treats ``None`` as "no type information available" and falls +back to module-only resolution (and, ultimately, the legacy decoder +with a warning). +""" + +from __future__ import annotations + +import inspect +import logging +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +def _unwrap_function_builder(name_or_callable: Any) -> Optional[Callable]: + """Return the underlying user function from a V2 ``FunctionBuilder``. 
+ + Returns ``None`` for plain strings, plain callables, or anything we + don't recognize. + """ + # Avoid a hard dependency on the FunctionBuilder symbol (it lives in + # the azure-functions package and may move). + func = getattr(getattr(name_or_callable, "_function", None), "_func", None) + if callable(func): + return func + return None + + +def _return_annotation(fn: Callable) -> Optional[type]: + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return None + ann = sig.return_annotation + if ann is inspect.Signature.empty: + return None + return ann if isinstance(ann, type) else None + + +def activity_output_type(name_or_callable: Any) -> Optional[type]: + """Discover the return-annotation type of a V2 activity function. + + Returns ``None`` if ``name_or_callable`` is a plain string (V1 model + or hand-written name) or if the annotation isn't a concrete type. + """ + fn = _unwrap_function_builder(name_or_callable) + if fn is None: + return None + return _return_annotation(fn) + + +def sub_orchestrator_output_type(name_or_callable: Any) -> Optional[type]: + """Discover the return-annotation type of a V2 sub-orchestrator function.""" + fn = _unwrap_function_builder(name_or_callable) + if fn is None: + return None + return _return_annotation(fn) + + +def entity_operation_input_type(entity_user_fn: Optional[Callable], + operation_name: str) -> Optional[type]: + """Best-effort discovery of an entity operation's input type. + + Entities in the V2 model are typically a single function that + dispatches on ``context.operation_name``. There is no general way to + statically associate an operation name with a parameter type; this + helper currently returns ``None`` for all such functions and exists + as the extension point for richer entity-dispatch patterns we may + add in the future (e.g. class-based entities with one method per + operation). + """ + if entity_user_fn is None or not operation_name: + return None + # Future work: inspect class-based entity dispatch tables. For now, + # signal "unknown" so the codec falls back to module-only resolution. + return None diff --git a/azure/durable_functions/orchestrator.py b/azure/durable_functions/orchestrator.py index 9e3a29b2..3717cf38 100644 --- a/azure/durable_functions/orchestrator.py +++ b/azure/durable_functions/orchestrator.py @@ -66,7 +66,12 @@ def handle(context: func.OrchestrationContext) -> str: context_body = getattr(context, "body", None) if context_body is None: context_body = context - return Orchestrator(fn).handle(DurableOrchestrationContext.from_json(context_body)) + ctx = DurableOrchestrationContext.from_json(context_body) + # Propagate the decorator-declared input type (set by + # @app.orchestration_trigger(input_type=...)) so that + # context.get_input() can decode the payload type-safely. 
+ ctx._input_expected_type = getattr(handle, "_df_input_type", None) + return Orchestrator(fn).handle(ctx) handle.orchestrator_function = fn diff --git a/tests/models/test_Decorators.py b/tests/models/test_Decorators.py index cf6d1148..0c753f77 100644 --- a/tests/models/test_Decorators.py +++ b/tests/models/test_Decorators.py @@ -34,6 +34,22 @@ def dummy_function(my_context): ] }) +def test_orchestration_trigger_input_type_stashed(app): + """Verify that input_type= on the decorator is stashed on the handle.""" + + class MyInput: + pass + + @app.orchestration_trigger(context_name="my_context", input_type=MyInput) + def dummy_function(my_context): + pass + + user_code = get_user_code(app) + assert user_code.get_function_name() == "dummy_function" + # The input type is stashed on the inner callable (the Orchestrator + # handle) which lives at Function._func. + assert getattr(user_code._func, "_df_input_type", None) is MyInput + def test_activity_trigger(app): @app.activity_trigger(input_name="my_input") diff --git a/tests/models/test_DurableOrchestrationContext.py b/tests/models/test_DurableOrchestrationContext.py index 3aecae5a..690837d5 100644 --- a/tests/models/test_DurableOrchestrationContext.py +++ b/tests/models/test_DurableOrchestrationContext.py @@ -101,6 +101,77 @@ def test_get_input_json_str(): assert 'Seattle' == result['city'] + +class _Order: + """Test fixture for expected_type round-trips.""" + def __init__(self, item: str, qty: int): + self.item = item + self.qty = qty + + @staticmethod + def to_json(obj): + return {"item": obj.item, "qty": obj.qty} + + @staticmethod + def from_json(data): + return _Order(data["item"], data["qty"]) + + +def test_get_input_with_expected_type_kwarg(): + from azure.durable_functions.models.utils.df_serialization import df_dumps + builder = ContextBuilder('test_function_context') + builder.input_ = df_dumps(_Order("widget", 5)) + context = DurableOrchestrationContext.from_json(builder.to_json_string()) + + result = context.get_input(expected_type=_Order) + assert isinstance(result, _Order) + assert result.item == "widget" + assert result.qty == 5 + + +def test_get_input_with_decorator_input_type(): + from azure.durable_functions.models.utils.df_serialization import df_dumps + builder = ContextBuilder('test_function_context') + builder.input_ = df_dumps(_Order("widget", 5)) + context = DurableOrchestrationContext.from_json(builder.to_json_string()) + # Simulate what Orchestrator.create does when input_type is set + context._input_expected_type = _Order + + result = context.get_input() + assert isinstance(result, _Order) + assert result.item == "widget" + + +def test_get_input_kwarg_overrides_decorator_type(): + """Call-site expected_type takes precedence over decorator input_type.""" + from azure.durable_functions.models.utils.df_serialization import df_dumps + + class _Alt: + def __init__(self, item, qty): + self.item = item + self.qty = qty + + @staticmethod + def to_json(obj): + return {"item": obj.item, "qty": obj.qty} + + @staticmethod + def from_json(data): + return _Alt(data["item"], data["qty"]) + + builder = ContextBuilder('test_function_context') + builder.input_ = df_dumps(_Order("widget", 5)) + context = DurableOrchestrationContext.from_json(builder.to_json_string()) + context._input_expected_type = _Order # decorator says _Order + + # expected_type is used for pre-validation only; the legacy decoder + # still uses the payload's declared class. A warning is emitted + # because _Alt != _Order. 
+ result = context.get_input(expected_type=_Alt) + assert isinstance(result, _Order) # legacy decoder uses payload class + assert result.item == "widget" + + def test_version_equals_version_from_execution_started_event(): builder = ContextBuilder('test_function_context') builder.history_events = [] diff --git a/tests/orchestrator/test_expected_type.py b/tests/orchestrator/test_expected_type.py new file mode 100644 index 00000000..58bcadb4 --- /dev/null +++ b/tests/orchestrator/test_expected_type.py @@ -0,0 +1,164 @@ +"""Tests for the expected_type kwarg on orchestration context APIs. + +Covers call_activity, call_sub_orchestrator, and their _with_retry variants +when an explicit expected_type is provided at the call site (V1 string-name +callers with no auto-discovery). +""" +import json +from datetime import datetime + +from tests.orchestrator.orchestrator_test_utils import ( + assert_orchestration_state_equals, + get_orchestration_state_result, +) +from tests.test_utils.ContextBuilder import ContextBuilder +from azure.durable_functions.models.OrchestratorState import OrchestratorState +from azure.durable_functions.models.actions.CallActivityAction import CallActivityAction +from azure.durable_functions.models.actions.CallSubOrchestratorAction import CallSubOrchestratorAction +from azure.durable_functions.models.RetryOptions import RetryOptions +from azure.durable_functions.models.utils.df_serialization import df_dumps + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _Order: + def __init__(self, item: str, qty: int = 1): + self.item = item + self.qty = qty + + @staticmethod + def to_json(obj): + return {"item": obj.item, "qty": obj.qty} + + @staticmethod + def from_json(data): + return _Order(data["item"], data["qty"]) + + +def _base_state(output=None) -> OrchestratorState: + return OrchestratorState(is_done=False, actions=[], output=output) + + +def _add_activity_completed(ctx_builder, id_, result_str, name="DoWork"): + ctx_builder.add_task_scheduled_event(name=name, id_=id_) + ctx_builder.add_orchestrator_completed_event() + ctx_builder.add_orchestrator_started_event() + ctx_builder.add_task_completed_event(id_=id_, result=result_str) + + +def _add_sub_orch_completed(ctx_builder, id_, result_str, name="SubOrch"): + ctx_builder.add_sub_orchestrator_started_event(name=name, id_=id_, input_="") + ctx_builder.add_orchestrator_completed_event() + ctx_builder.add_orchestrator_started_event() + ctx_builder.add_sub_orchestrator_completed_event(result=result_str, id_=id_) + + +# --------------------------------------------------------------------------- +# call_activity with expected_type +# --------------------------------------------------------------------------- + +def orchestrator_activity_expected_type(context): + result = yield context.call_activity("DoWork", "x", expected_type=_Order) + return result.item + + +def test_call_activity_with_expected_type(): + payload = df_dumps(_Order("widget", 5)) + ctx = ContextBuilder("test") + _add_activity_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_activity_expected_type) + + assert result["isDone"] is True + # The orchestrator returns result.item which is "widget" + assert result["output"] == "widget" + + +# --------------------------------------------------------------------------- +# call_activity_with_retry with expected_type +# 
--------------------------------------------------------------------------- + +def orchestrator_activity_retry_expected_type(context): + opts = RetryOptions(5000, 3) + result = yield context.call_activity_with_retry( + "DoWork", opts, "x", expected_type=_Order) + return result.item + + +def test_call_activity_with_retry_expected_type(): + payload = df_dumps(_Order("gadget", 2)) + ctx = ContextBuilder("test") + _add_activity_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_activity_retry_expected_type) + + assert result["isDone"] is True + assert result["output"] == "gadget" + + +# --------------------------------------------------------------------------- +# call_sub_orchestrator with expected_type +# --------------------------------------------------------------------------- + +def orchestrator_sub_orch_expected_type(context): + result = yield context.call_sub_orchestrator( + "SubOrch", "input", expected_type=_Order) + return result.item + + +def test_call_sub_orchestrator_with_expected_type(): + payload = df_dumps(_Order("part", 10)) + ctx = ContextBuilder("test") + _add_sub_orch_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_sub_orch_expected_type) + + assert result["isDone"] is True + assert result["output"] == "part" + + +# --------------------------------------------------------------------------- +# call_sub_orchestrator_with_retry with expected_type +# --------------------------------------------------------------------------- + +def orchestrator_sub_orch_retry_expected_type(context): + opts = RetryOptions(5000, 3) + result = yield context.call_sub_orchestrator_with_retry( + "SubOrch", opts, "input", expected_type=_Order) + return result.item + + +def test_call_sub_orchestrator_with_retry_expected_type(): + payload = df_dumps(_Order("gizmo", 3)) + ctx = ContextBuilder("test") + _add_sub_orch_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_sub_orch_retry_expected_type) + + assert result["isDone"] is True + assert result["output"] == "gizmo" + + +# --------------------------------------------------------------------------- +# expected_type kwarg overrides auto-discovered type (None in V1) +# --------------------------------------------------------------------------- + +def orchestrator_override(context): + """Call with string name (V1) + expected_type; auto-discovery returns None.""" + result = yield context.call_activity("DoWork", "x", expected_type=_Order) + return [result.item, result.qty] + + +def test_expected_type_kwarg_used_when_auto_discovery_returns_none(): + payload = df_dumps(_Order("bolt", 99)) + ctx = ContextBuilder("test") + _add_activity_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_override) + + assert result["isDone"] is True + output = result["output"] + assert output[0] == "bolt" + assert output[1] == 99 diff --git a/tests/orchestrator/test_external_event.py b/tests/orchestrator/test_external_event.py index 263ef774..86df6107 100644 --- a/tests/orchestrator/test_external_event.py +++ b/tests/orchestrator/test_external_event.py @@ -3,6 +3,7 @@ from tests.orchestrator.orchestrator_test_utils import assert_orchestration_state_equals, get_orchestration_state_result from tests.test_utils.ContextBuilder import ContextBuilder from azure.durable_functions.models.actions.WaitForExternalEventAction import WaitForExternalEventAction +from azure.durable_functions.models.utils.df_serialization import df_dumps def 
generator_function(context): result = yield context.wait_for_external_event("A") @@ -51,4 +52,37 @@ def test_succeeds_on_out_of_order_payload(): expected_state.actions.append([WaitForExternalEventAction("B")]) expected_state._is_done = True expected = expected_state.to_json() - assert_orchestration_state_equals(expected, result) \ No newline at end of file + assert_orchestration_state_equals(expected, result) + + +class _Payload: + """Simple custom class for testing expected_type on external events.""" + def __init__(self, value: str): + self.value = value + + @staticmethod + def to_json(obj): + return {"value": obj.value} + + @staticmethod + def from_json(data): + return _Payload(data["value"]) + + +def generator_function_with_expected_type(context): + result = yield context.wait_for_external_event("A", expected_type=_Payload) + return result.value + + +def test_external_event_with_expected_type(): + """wait_for_external_event(expected_type=...) decodes custom objects.""" + timestamp = datetime.now() + json_input = df_dumps(_Payload("hello")) + context_builder = ContextBuilder() + context_builder.add_event_raised_event( + "A", input_=json_input, timestamp=timestamp, id_=-1) + result = get_orchestration_state_result( + context_builder, generator_function_with_expected_type) + + assert result["isDone"] is True + assert result["output"] == "hello" \ No newline at end of file diff --git a/tests/utils/test_df_serialization.py b/tests/utils/test_df_serialization.py new file mode 100644 index 00000000..6c1d692d --- /dev/null +++ b/tests/utils/test_df_serialization.py @@ -0,0 +1,656 @@ +"""Comprehensive round-trip and validation tests for df_serialization. + +Every data shape is tested in three configurations: + 1. No expected_type (legacy object_hook path) + 2. Loose mode + expected_type (warn on mismatch, legacy deserialize) + 3. 
Strict mode + expected_type (raise on mismatch, from_json directly) +""" + +import json +import logging +import os + +import pytest + +from azure.durable_functions.models.utils import df_serialization +from azure.durable_functions.models.utils.df_serialization import ( + df_dumps, + df_loads, + _get_serialize_default, + _STRICT_ENV_VAR, +) + + +# --------------------------------------------------------------------------- +# Helper classes +# --------------------------------------------------------------------------- + +class PlainPerson: + """Simple class: to_json returns a dict, from_json accepts a dict.""" + + def __init__(self, name: str, age: int): + self.name = name + self.age = age + + @staticmethod + def to_json(obj): + return {"name": obj.name, "age": obj.age} + + @staticmethod + def from_json(data): + return PlainPerson(data["name"], data["age"]) + + def __eq__(self, other): + return (isinstance(other, PlainPerson) + and self.name == other.name and self.age == other.age) + + +class ScalarPerson: + """to_json returns a scalar (str), not a dict.""" + + def __init__(self, name: str): + self.name = name + + @staticmethod + def to_json(obj): + return obj.name + + @staticmethod + def from_json(data): + return ScalarPerson(data) + + def __eq__(self, other): + return isinstance(other, ScalarPerson) and self.name == other.name + + +class Hat: + """Leaf object for nesting tests.""" + + def __init__(self, color: str): + self.color = color + + @staticmethod + def to_json(obj): + return {"color": obj.color} + + @staticmethod + def from_json(data): + return Hat(data["color"]) + + def __eq__(self, other): + return isinstance(other, Hat) and self.color == other.color + + +class NaiveOrder: + """Nested object whose from_json expects pre-constructed Hat instances. + + This relies on the bottom-up object_hook behavior — from_json receives + a Hat instance at data["hat"], not a raw dict. Works in loose mode but + fails in strict mode because strict skips object_hook. + """ + + def __init__(self, item: str, hat: Hat): + self.item = item + self.hat = hat + + @staticmethod + def to_json(obj): + return {"item": obj.item, "hat": obj.hat} + + @staticmethod + def from_json(data): + # Assumes data["hat"] is already a Hat instance (object_hook fired) + return NaiveOrder(data["item"], data["hat"]) + + def __eq__(self, other): + return (isinstance(other, NaiveOrder) + and self.item == other.item and self.hat == other.hat) + + +class SmartOrder: + """Nested object with strict-mode-compatible to_json / from_json. + + to_json produces plain JSON (calls Hat.to_json explicitly), so the + result is natively JSON-serializable without ``default=``. from_json + handles both the strict-mode shape (plain dict from to_json) and + the loose-mode shape (pre-constructed Hat or raw legacy dict). 
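+
+    Example envelope for SmartOrder("gadget", Hat("blue")), identical in
+    loose and strict mode (asserted below in TestWireFormat and
+    TestStrictEncode); note that __data__ is plain JSON with no nested
+    legacy envelope for the Hat:
+
+        {"__class__": "SmartOrder",
+         "__module__": <this test module>,
+         "__data__": {"item": "gadget", "hat": {"color": "blue"}}}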
+ """ + + def __init__(self, item: str, hat: Hat): + self.item = item + self.hat = hat + + @staticmethod + def to_json(obj): + return {"item": obj.item, "hat": Hat.to_json(obj.hat)} + + @staticmethod + def from_json(data): + hat_data = data["hat"] + if isinstance(hat_data, Hat): + # Loose mode: object_hook already constructed the Hat + hat = hat_data + else: + # Strict mode or plain dict: reconstruct from to_json output + hat = Hat.from_json(hat_data) + return SmartOrder(data["item"], hat) + + def __eq__(self, other): + return (isinstance(other, SmartOrder) + and self.item == other.item and self.hat == other.hat) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def strict(monkeypatch): + """Enable strict typing mode for the duration of a test.""" + monkeypatch.setenv(_STRICT_ENV_VAR, "1") + + +@pytest.fixture +def loose(monkeypatch): + """Explicitly disable strict typing mode.""" + monkeypatch.delenv(_STRICT_ENV_VAR, raising=False) + + +# =================================================================== +# 1. PRIMITIVES (str, int, float, bool, None, list, dict) +# =================================================================== + +@pytest.mark.parametrize("value", [ + None, + True, + False, + 0, + -1, + 42, + 3.14, + "", + "hello", + [], + [1, 2, 3], + [True, None, "mixed"], + {}, + {"a": 1, "b": [1, 2]}, + {"nested": {"deep": {"value": 7}}}, +]) +class TestPrimitiveRoundTrips: + """Primitives must round-trip identically in all three paths.""" + + def test_no_expected_type(self, value): + assert df_loads(df_dumps(value)) == value + + def test_loose_with_matching_type(self, value, loose, caplog): + # Use the actual type of the value as expected_type + et = type(value) if value is not None else type(None) + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps(value), expected_type=et) + assert result == value + + def test_strict_with_matching_type(self, value, strict): + et = type(value) if value is not None else type(None) + result = df_loads(df_dumps(value), expected_type=et) + assert result == value + + +# =================================================================== +# 2. 
SIMPLE CUSTOM OBJECTS (dict-returning to_json) +# =================================================================== + +class TestSimpleObject: + + def test_no_expected_type(self): + obj = PlainPerson("andy", 99) + decoded = df_loads(df_dumps(obj)) + assert decoded == obj + + def test_loose_matching_type(self, loose): + obj = PlainPerson("andy", 99) + decoded = df_loads(df_dumps(obj), expected_type=PlainPerson) + assert decoded == obj + + def test_strict_matching_type(self, strict): + obj = PlainPerson("andy", 99) + decoded = df_loads(df_dumps(obj), expected_type=PlainPerson) + assert decoded == obj + + def test_loose_mismatched_type_warns(self, loose, caplog): + encoded = df_dumps(PlainPerson("a", 1)) + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + decoded = df_loads(encoded, expected_type=ScalarPerson) + # Loose mode: legacy decoder uses the payload's class + assert isinstance(decoded, PlainPerson) + assert any("payload declares class" in r.message for r in caplog.records) + + def test_strict_mismatched_type_raises(self, strict): + encoded = df_dumps(PlainPerson("a", 1)) + with pytest.raises(TypeError, match="payload declares class"): + df_loads(encoded, expected_type=ScalarPerson) + + +# =================================================================== +# 3. SCALAR-RETURNING to_json +# =================================================================== + +class TestScalarToJson: + + def test_no_expected_type(self): + obj = ScalarPerson("andy") + decoded = df_loads(df_dumps(obj)) + assert decoded == obj + + def test_loose_matching_type(self, loose): + obj = ScalarPerson("andy") + decoded = df_loads(df_dumps(obj), expected_type=ScalarPerson) + assert decoded == obj + + def test_strict_matching_type(self, strict): + obj = ScalarPerson("andy") + decoded = df_loads(df_dumps(obj), expected_type=ScalarPerson) + assert decoded == obj + + +# =================================================================== +# 4. DICT WITH OBJECT PROPERTIES e.g. {"person": PlainPerson(...)} +# =================================================================== + +class TestDictWithObjectProperty: + """A plain dict containing a custom object as a value.""" + + def _make_payload(self): + return {"person": PlainPerson("a", 1), "count": 7} + + def test_no_expected_type(self): + """Loose path: object_hook reconstructs nested objects.""" + decoded = df_loads(df_dumps(self._make_payload())) + assert decoded["count"] == 7 + assert isinstance(decoded["person"], PlainPerson) + assert decoded["person"].name == "a" + + def test_loose_expected_dict(self, loose, caplog): + """Loose path + expected_type=dict: works, inner objects reconstructed.""" + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + decoded = df_loads(df_dumps(self._make_payload()), expected_type=dict) + assert isinstance(decoded["person"], PlainPerson) + # No warning — top-level is a dict matching expected_type + assert not any("not compatible" in r.message for r in caplog.records) + + def test_strict_encode_fails_for_nested_custom_objects(self, strict): + """Strict mode: a plain dict containing a custom object cannot be + encoded — json.dumps runs without default= so Hat raises TypeError.""" + with pytest.raises(TypeError): + df_dumps(self._make_payload()) + + +# =================================================================== +# 5. 
NESTED OBJECTS — "naive" from_json (expects pre-constructed) +# =================================================================== + +class TestNaiveNestedObject: + """NaiveOrder.from_json expects Hat to already be a Hat instance.""" + + def _make(self): + return NaiveOrder("widget", Hat("red")) + + def test_no_expected_type(self): + """Legacy path: object_hook fires bottom-up, Hat constructed first.""" + decoded = df_loads(df_dumps(self._make())) + assert isinstance(decoded, NaiveOrder) + assert isinstance(decoded.hat, Hat) + assert decoded.hat.color == "red" + + def test_loose_matching_type(self, loose): + """Loose + expected_type: legacy path still fires, nested works.""" + decoded = df_loads(df_dumps(self._make()), expected_type=NaiveOrder) + assert decoded == self._make() + + def test_strict_encode_fails_for_naive_to_json(self, strict): + """Strict mode: NaiveOrder.to_json returns a Hat instance, which + is not natively JSON-serializable. df_dumps should fail at encode.""" + with pytest.raises(TypeError): + df_dumps(self._make()) + + +# =================================================================== +# 6. NESTED OBJECTS — "smart" from_json (handles raw dicts) +# =================================================================== + +class TestSmartNestedObject: + """SmartOrder.from_json manually calls Hat.from_json when needed.""" + + def _make(self): + return SmartOrder("gadget", Hat("blue")) + + def test_no_expected_type(self): + decoded = df_loads(df_dumps(self._make())) + assert isinstance(decoded, SmartOrder) + assert decoded.hat == Hat("blue") + + def test_loose_matching_type(self, loose): + decoded = df_loads(df_dumps(self._make()), expected_type=SmartOrder) + assert decoded == self._make() + + def test_strict_matching_type(self, strict): + """Strict mode works: SmartOrder.from_json handles the raw dict.""" + decoded = df_loads(df_dumps(self._make()), expected_type=SmartOrder) + assert decoded == self._make() + assert isinstance(decoded.hat, Hat) + assert decoded.hat.color == "blue" + + +# =================================================================== +# 7. LIST OF OBJECTS +# =================================================================== + +class TestListOfObjects: + + def _make(self): + return [PlainPerson("a", 1), PlainPerson("b", 2)] + + def test_no_expected_type(self): + decoded = df_loads(df_dumps(self._make())) + assert len(decoded) == 2 + assert all(isinstance(p, PlainPerson) for p in decoded) + + def test_loose_expected_list(self, loose): + decoded = df_loads(df_dumps(self._make()), expected_type=list) + assert len(decoded) == 2 + assert all(isinstance(p, PlainPerson) for p in decoded) + + def test_strict_encode_fails_for_nested_custom_objects(self, strict): + """Strict mode: a list of custom objects cannot be encoded — the + list itself doesn't have to_json, and json.dumps runs without + default= so PlainPerson raises TypeError.""" + with pytest.raises(TypeError): + df_dumps(self._make()) + + +# =================================================================== +# 8. 
PRIMITIVE TYPE MISMATCHES +# =================================================================== + +class TestPrimitiveTypeMismatch: + + def test_loose_warns(self, loose, caplog): + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps("hello"), expected_type=int) + assert result == "hello" + assert any("not compatible" in r.message for r in caplog.records) + + def test_strict_raises(self, strict): + with pytest.raises(TypeError, match="not compatible with expected type"): + df_loads(df_dumps("hello"), expected_type=int) + + def test_loose_str_expected_dict_warns(self, loose, caplog): + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps("hello"), expected_type=dict) + assert result == "hello" + assert any("not compatible" in r.message for r in caplog.records) + + def test_strict_str_expected_dict_raises(self, strict): + with pytest.raises(TypeError): + df_loads(df_dumps("hello"), expected_type=dict) + + +# =================================================================== +# 9. typing CONSTRUCTS (List[int], Optional[str], etc.) +# =================================================================== + +class TestTypingConstructs: + """Generic type hints can't be validated with isinstance — we pass + through without error in both modes.""" + + def test_loose_list_of_int(self, loose): + from typing import List + decoded = df_loads(df_dumps([1, 2, 3]), expected_type=List[int]) + assert decoded == [1, 2, 3] + + def test_strict_list_of_int(self, strict): + from typing import List + decoded = df_loads(df_dumps([1, 2, 3]), expected_type=List[int]) + assert decoded == [1, 2, 3] + + def test_loose_optional_str(self, loose): + from typing import Optional + decoded = df_loads(df_dumps("hi"), expected_type=Optional[str]) + assert decoded == "hi" + + +# =================================================================== +# 10. STRICT MODE ENV VAR VALUES +# =================================================================== + +class TestStrictModeEnvVar: + + @pytest.mark.parametrize("val", ["1", "true", "yes", "TRUE", "Yes", " 1 "]) + def test_truthy_values_enable_strict(self, monkeypatch, val): + monkeypatch.setenv(_STRICT_ENV_VAR, val) + with pytest.raises(TypeError): + df_loads(df_dumps("hello"), expected_type=int) + + @pytest.mark.parametrize("val", ["0", "false", "no", "", "nope"]) + def test_non_truthy_values_stay_loose(self, monkeypatch, val, caplog): + monkeypatch.setenv(_STRICT_ENV_VAR, val) + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps("hello"), expected_type=int) + assert result == "hello" + + def test_unset_is_loose(self, monkeypatch): + monkeypatch.delenv(_STRICT_ENV_VAR, raising=False) + result = df_loads(df_dumps("hello"), expected_type=int) + assert result == "hello" + + +# =================================================================== +# 10b. 
STRICT MODE WITHOUT expected_type +# =================================================================== + +class TestStrictNoExpectedType: + """In strict mode, df_loads without expected_type must never call import_module.""" + + def test_primitive_returns_raw(self, strict): + assert df_loads(df_dumps(42)) == 42 + + def test_string_returns_raw(self, strict): + assert df_loads(df_dumps("hello")) == "hello" + + def test_none_returns_raw(self, strict): + assert df_loads(df_dumps(None)) is None + + def test_plain_dict_returns_raw(self, strict): + d = {"key": "value", "n": 1} + assert df_loads(df_dumps(d)) == d + + def test_plain_list_returns_raw(self, strict): + lst = [1, "two", None] + assert df_loads(df_dumps(lst)) == lst + + def test_custom_object_raises(self, strict): + s = df_dumps(PlainPerson("alice", 30)) + with pytest.raises(TypeError, match="strict mode requires expected_type"): + df_loads(s) + + def test_custom_object_error_includes_class(self, strict): + s = df_dumps(PlainPerson("alice", 30)) + with pytest.raises(TypeError, match="PlainPerson"): + df_loads(s) + + def test_loose_mode_custom_object_still_works(self, loose): + """Without strict, the legacy path runs even without expected_type.""" + p = PlainPerson("bob", 25) + result = df_loads(df_dumps(p)) + assert isinstance(result, PlainPerson) + assert result.name == "bob" + + +# =================================================================== +# 11. WIRE FORMAT VERIFICATION +# =================================================================== + +class TestWireFormat: + + def test_df_dumps_matches_legacy_json_dumps(self): + from azure.functions._durable_functions import _serialize_custom_object + value = {"key": "value", "list": [1, 2, 3]} + assert df_dumps(value) == json.dumps(value, default=_serialize_custom_object) + + def test_custom_object_produces_legacy_keys(self): + raw = json.loads(df_dumps(PlainPerson("andy", 99))) + assert raw == { + "__class__": "PlainPerson", + "__module__": __name__, + "__data__": {"name": "andy", "age": 99}, + } + + def test_scalar_to_json_produces_legacy_keys(self): + raw = json.loads(df_dumps(ScalarPerson("andy"))) + assert raw == { + "__class__": "ScalarPerson", + "__module__": __name__, + "__data__": "andy", + } + + def test_nested_object_produces_plain_json_data(self): + """SmartOrder.to_json serializes Hat explicitly, so __data__ + contains plain JSON — no nested legacy envelope.""" + raw = json.loads(df_dumps(SmartOrder("gadget", Hat("blue")))) + assert raw["__class__"] == "SmartOrder" + assert raw["__data__"] == {"item": "gadget", "hat": {"color": "blue"}} + + +# =================================================================== +# 12. _get_serialize_default +# =================================================================== + +class TestGetSerializeDefault: + + def test_returns_callable(self): + cb = _get_serialize_default() + assert callable(cb) + + def test_produces_legacy_dict(self): + cb = _get_serialize_default() + result = cb(PlainPerson("a", 1)) + assert result == { + "__class__": "PlainPerson", + "__module__": __name__, + "__data__": {"name": "a", "age": 1}, + } + + def test_strict_returns_none(self, strict): + cb = _get_serialize_default() + assert cb is None + + +# =================================================================== +# 13. 
ENCODE ERRORS +# =================================================================== + +class TestEncodeErrors: + + def test_class_without_to_json(self): + class NoProtocol: + pass + with pytest.raises(TypeError): + df_dumps(NoProtocol()) + + def test_set(self): + with pytest.raises(TypeError): + df_dumps({1, 2, 3}) + + def test_bytes(self): + with pytest.raises(TypeError): + df_dumps(b"hello") + + +# =================================================================== +# 13b. STRICT-MODE ENCODE +# =================================================================== + +class TestStrictEncode: + """In strict mode, df_dumps rejects non-serializable nested values.""" + + def test_primitive(self, strict): + assert df_dumps(42) == "42" + + def test_string(self, strict): + assert df_dumps("hello") == '"hello"' + + def test_plain_dict(self, strict): + assert json.loads(df_dumps({"a": 1})) == {"a": 1} + + def test_custom_object_top_level_ok(self, strict): + """Top-level custom object is wrapped in envelope.""" + raw = json.loads(df_dumps(PlainPerson("andy", 99))) + assert raw["__class__"] == "PlainPerson" + assert raw["__data__"] == {"name": "andy", "age": 99} + + def test_strict_smart_order_data_is_plain_json(self, strict): + """SmartOrder.to_json returns plain JSON, so encoding succeeds + and __data__ contains no nested envelopes.""" + raw = json.loads(df_dumps(SmartOrder("gadget", Hat("blue")))) + assert raw["__class__"] == "SmartOrder" + assert raw["__data__"] == {"item": "gadget", "hat": {"color": "blue"}} + + def test_strict_naive_order_fails(self, strict): + """NaiveOrder.to_json returns a Hat instance — not serializable.""" + with pytest.raises(TypeError): + df_dumps(NaiveOrder("widget", Hat("red"))) + + def test_strict_dict_with_custom_value_fails(self, strict): + """Plain dict containing a custom object — not serializable.""" + with pytest.raises(TypeError): + df_dumps({"person": PlainPerson("a", 1)}) + + def test_strict_list_with_custom_value_fails(self, strict): + """List containing custom objects — not serializable.""" + with pytest.raises(TypeError): + df_dumps([PlainPerson("a", 1)]) + + def test_loose_dict_with_custom_value_ok(self, loose): + """In loose mode, nested custom objects are still auto-wrapped.""" + raw = json.loads(df_dumps({"person": PlainPerson("a", 1)})) + assert raw["person"]["__class__"] == "PlainPerson" + + +# =================================================================== +# 14. 
EDGE CASES +# =================================================================== + +class TestEdgeCases: + + def test_bool_does_not_become_int(self): + """bool is a subclass of int — verify it stays bool.""" + out = df_loads(df_dumps(True)) + assert out is True + assert isinstance(out, bool) + + def test_none_with_expected_type_nonetype(self, loose): + assert df_loads(df_dumps(None), expected_type=type(None)) is None + + def test_none_with_expected_type_nonetype_strict(self, strict): + assert df_loads(df_dumps(None), expected_type=type(None)) is None + + def test_empty_dict_expected_dict(self, loose): + assert df_loads(df_dumps({}), expected_type=dict) == {} + + def test_empty_list_expected_list(self, strict): + assert df_loads(df_dumps([]), expected_type=list) == [] + + def test_tuple_becomes_list(self): + """Tuples serialize as JSON arrays — come back as lists.""" + assert df_loads(df_dumps((1, 2, 3))) == [1, 2, 3] + + def test_int_dict_keys_become_strings(self): + decoded = df_loads(df_dumps({1: "one", 2: "two"})) + assert decoded == {"1": "one", "2": "two"} + + def test_no_expected_type_no_warning(self, caplog): + """When expected_type is None, no warnings should fire.""" + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + df_loads(df_dumps(PlainPerson("a", 1))) + assert not any("not compatible" in r.message for r in caplog.records) + assert not any("payload declares" in r.message for r in caplog.records) diff --git a/tests/utils/test_type_discovery.py b/tests/utils/test_type_discovery.py new file mode 100644 index 00000000..a5b83904 --- /dev/null +++ b/tests/utils/test_type_discovery.py @@ -0,0 +1,81 @@ +"""Tests for type_discovery helpers.""" + +from typing import Optional +from unittest.mock import MagicMock + +from azure.durable_functions.models.utils.type_discovery import ( + activity_output_type, + sub_orchestrator_output_type, + entity_operation_input_type, +) + + +class _Result: + pass + + +def _make_function_builder(fn): + """Build a minimal stand-in for FunctionBuilder._function._func.""" + fb = MagicMock() + fb._function._func = fn + return fb + + +# --------------------------------------------------------------------------- +# activity_output_type +# --------------------------------------------------------------------------- + +def test_activity_output_type_returns_annotation(): + def my_activity(x) -> _Result: + return _Result() + fb = _make_function_builder(my_activity) + assert activity_output_type(fb) is _Result + + +def test_activity_output_type_returns_none_for_string(): + assert activity_output_type("activity_name") is None + + +def test_activity_output_type_returns_none_when_unannotated(): + def my_activity(x): + return None + fb = _make_function_builder(my_activity) + assert activity_output_type(fb) is None + + +def test_activity_output_type_returns_none_for_typing_construct(): + def my_activity(x) -> Optional[_Result]: + return None + fb = _make_function_builder(my_activity) + # Optional[_Result] is not a concrete class, so we return None. 
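+    # (df_loads could not isinstance-check a typing generic anyway; see
+    # TestTypingConstructs in test_df_serialization.py.)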
+ assert activity_output_type(fb) is None + + +# --------------------------------------------------------------------------- +# sub_orchestrator_output_type (same shape as activity) +# --------------------------------------------------------------------------- + +def test_sub_orchestrator_output_type_returns_annotation(): + def my_sub_orch(ctx) -> _Result: + return _Result() + fb = _make_function_builder(my_sub_orch) + assert sub_orchestrator_output_type(fb) is _Result + + +def test_sub_orchestrator_output_type_returns_none_for_string(): + assert sub_orchestrator_output_type("orch_name") is None + + +# --------------------------------------------------------------------------- +# entity_operation_input_type (always None today) +# --------------------------------------------------------------------------- + +def test_entity_operation_input_type_returns_none(): + def my_entity(ctx): + pass + assert entity_operation_input_type(my_entity, "add") is None + + +def test_entity_operation_input_type_returns_none_for_missing_inputs(): + assert entity_operation_input_type(None, "add") is None + assert entity_operation_input_type(lambda ctx: None, "") is None
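+
+
+# ---------------------------------------------------------------------------
+# Illustrative sketch only: shows that a discovered return annotation can be
+# handed to df_loads as expected_type. The _Widget helper below is defined
+# just for this sketch; orchestrator-level behavior is covered separately in
+# tests/orchestrator/test_expected_type.py.
+# ---------------------------------------------------------------------------
+
+class _Widget:
+    """Module-level so the loose-mode object_hook can re-import it."""
+
+    def __init__(self, value: str):
+        self.value = value
+
+    @staticmethod
+    def to_json(obj):
+        return {"value": obj.value}
+
+    @staticmethod
+    def from_json(data):
+        return _Widget(data["value"])
+
+
+def test_discovered_annotation_feeds_df_loads():
+    from azure.durable_functions.models.utils.df_serialization import (
+        df_dumps, df_loads)
+
+    def my_activity(x) -> _Widget:
+        return _Widget(x)
+
+    fb = _make_function_builder(my_activity)
+    discovered = activity_output_type(fb)
+    assert discovered is _Widget
+
+    decoded = df_loads(df_dumps(_Widget("hi")), expected_type=discovered)
+    assert isinstance(decoded, _Widget)
+    assert decoded.value == "hi"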