diff --git a/README.md b/README.md index f26c7b837..1ed2a228f 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ informal introduction to the features and their implementation. - [Data Conversion](#data-conversion) - [Pydantic Support](#pydantic-support) - [Custom Type Data Conversion](#custom-type-data-conversion) + - [External Payload Storage](#external-payload-storage) - [Workers](#workers) - [Workflows](#workflows) - [Definition](#definition) @@ -309,8 +310,9 @@ other_ns_client = Client(**config) Data converters are used to convert raw Temporal payloads to/from actual Python types. A custom data converter of type `temporalio.converter.DataConverter` can be set via the `data_converter` parameter of the `Client` constructor. Data -converters are a combination of payload converters, payload codecs, and failure converters. Payload converters convert -Python values to/from serialized bytes. Payload codecs convert bytes to bytes (e.g. for compression or encryption). +converters are a combination of payload converters, external payload storage, payload codecs, and failure converters. Payload +converters convert Python values to/from serialized bytes. External payload storage optionally stores and retrieves payloads +to/from external storage services using drivers. Payload codecs convert bytes to bytes (e.g. for compression or encryption). Failure converters convert exceptions to/from serialized failures. The default data converter supports converting multiple types including: @@ -455,6 +457,131 @@ my_data_converter = dataclasses.replace( Now `IPv4Address` can be used in type hints including collections, optionals, etc. +##### External Payload Storage + +⚠️ **External payload storage support is currently at an experimental release stage.** ⚠️ + +External payload storage allows large payloads to be offloaded to an external storage service (such as Amazon S3) rather than stored inline in workflow history. 
This is useful when workflows or activities work with data that would otherwise exceed Temporal's payload size limits. + +External payload storage is configured via the `external_storage` parameter on `DataConverter`, which accepts a `temporalio.converter.ExternalStorage` instance. Any driver used to store payloads must also be configured on the component that retrieves them — for example, if the client stores workflow inputs using a driver, the worker must include that driver in its `ExternalStorage.drivers` list to retrieve them. + +The simplest setup uses a single storage driver: + +```python +import dataclasses +from temporalio.client import Client +from temporalio.converter import DataConverter +from temporalio.converter import ExternalStorage + +driver = MyDriver() + +client = await Client.connect( + "localhost:7233", + data_converter=dataclasses.replace( + DataConverter.default, + external_storage=ExternalStorage(drivers=[driver]), + ), +) +``` + +Some things to note about external payload storage: + +* Only payloads that meet or exceed `ExternalStorage.payload_size_threshold` (default 256 KiB) are offloaded. Smaller payloads are stored inline as normal. +* External payload storage applies transparently to workflow inputs/outputs, activity inputs/outputs, signals, updates, queries, and failure details. +* The `DataConverter`'s `payload_codec` (if configured) is applied to the *reference* payload stored in workflow history, not to the externally stored bytes. To encrypt or compress the bytes handed to a driver, use `ExternalStorage.payload_codec`. +* Setting `ExternalStorage.payload_size_threshold` to `None` causes every payload to be considered for external payload storage regardless of size. + +###### Multiple Drivers and Driver Selection + +When multiple storage backends are needed, list all drivers in `ExternalStorage.drivers` and provide a `driver_selector` to control which driver stores new payloads. 
Any driver in the list not chosen for storing is still available for retrieval, which is useful when migrating between storage backends. + +```python +from temporalio.converter import ExternalStorage + +options = ExternalStorage( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: ( + "hot-storage" if payload.ByteSize() < 5 * 1024 * 1024 else "cold-storage" + ), +) +``` + +For more complex selection logic, use a plain callable that reads from the `StorageDriverStoreContext`: + +```python +import temporalio.converter +from temporalio.api.common.v1 import Payload + +def feature_flag_is_on(workflow_id: str | None) -> bool: + """Check whether external storage is enabled for this workflow via a feature flag service.""" + return workflow_id is not None and len(workflow_id) % 2 == 0 + +def feature_flag_selector( + context: temporalio.converter.StorageDriverStoreContext, _payload: Payload +) -> str | None: + workflow_id = None + if isinstance(context.serialization_context, temporalio.converter.WorkflowSerializationContext): + workflow_id = context.serialization_context.workflow_id + elif isinstance(context.serialization_context, temporalio.converter.ActivitySerializationContext): + workflow_id = context.serialization_context.workflow_id + return "my-driver" if feature_flag_is_on(workflow_id) else None + +options = ExternalStorage( + drivers=[my_driver], + driver_selector=feature_flag_selector, +) +``` + +Some things to note about driver selection: + +* When no `driver_selector` is set, the first driver in `ExternalStorage.drivers` is always used for storing. +* Returning `None` from a selector leaves the payload stored inline in workflow history rather than offloading it. +* The driver name returned by the selector must match a driver registered in `ExternalStorage.drivers`. If it does not, an error is raised. 
+ +###### Custom Drivers + +Implement `temporalio.converter.StorageDriver` to integrate with an external storage system: + +```python +from collections.abc import Sequence +from temporalio.converter import StorageDriver, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext +from temporalio.api.common.v1 import Payload + +class MyDriver(StorageDriver): + def __init__(self, driver_name: str | None = None): + self._driver_name = driver_name or "my-org:driver:my-driver" + + def name(self) -> str: + return self._driver_name + + async def store( + self, context: StorageDriverStoreContext, payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: + claims = [] + for payload in payloads: + key = await my_storage.put(payload.SerializeToString()) + claims.append(StorageDriverClaim(data={"key": key})) + return claims + + async def retrieve( + self, context: StorageDriverRetrieveContext, claims: Sequence[StorageDriverClaim] + ) -> list[Payload]: + payloads = [] + for claim in claims: + data = await my_storage.get(claim.data["key"]) + p = Payload() + p.ParseFromString(data) + payloads.append(p) + return payloads +``` + +Some things to note about implementing a custom driver: + +* `store` and `retrieve` must return lists of the same length as their respective input sequences. +* `StorageDriver.name()` must return a string that is unique among all drivers in `ExternalStorage.drivers`. This name is embedded in the reference payload stored in workflow history and used to look up the correct driver during retrieval — changing it after payloads have been stored will break retrieval. +* `StorageDriver.type()` is automatically implemented to return the name of the class. This can be overridden in subclasses but must remain consistent across all instances of the subclass. +* Implement `temporalio.converter.WithSerializationContext` on your driver to receive workflow or activity context (namespace, workflow ID, activity ID, etc.) at serialization time. 
+ ### Workers Workers host workflows and/or activities. Here's how to run a worker: diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c98afefca..c2e426d28 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -303,10 +303,9 @@ async def decode_activation( decode_headers: bool, ) -> None: """Decode all payloads in the activation.""" - if data_converter._decode_payload_has_effect: - await CommandAwarePayloadVisitor( - skip_search_attributes=True, skip_headers=not decode_headers - ).visit(_Visitor(data_converter._decode_payload_sequence), activation) + await CommandAwarePayloadVisitor( + skip_search_attributes=True, skip_headers=not decode_headers + ).visit(_Visitor(data_converter._decode_payload_sequence), activation) async def encode_completion( diff --git a/temporalio/converter.py b/temporalio/converter/__init__.py similarity index 96% rename from temporalio/converter.py rename to temporalio/converter/__init__.py index dc37f5039..6ab3a5d6d 100644 --- a/temporalio/converter.py +++ b/temporalio/converter/__init__.py @@ -1359,6 +1359,13 @@ class DataConverter(WithSerializationContext): payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() """Settings for payload size limits.""" + external_storage: extstore.ExternalStorage | None = None + """Options for external storage. If None, external storage is disabled. + + .. warning:: + This API is experimental. 
+ """ + default: ClassVar[DataConverter] """Singleton default data converter.""" @@ -1445,18 +1452,22 @@ def with_context(self, context: SerializationContext) -> Self: payload_converter = self.payload_converter payload_codec = self.payload_codec failure_converter = self.failure_converter + external_storage = self.external_storage if isinstance(payload_converter, WithSerializationContext): payload_converter = payload_converter.with_context(context) if isinstance(payload_codec, WithSerializationContext): payload_codec = payload_codec.with_context(context) if isinstance(failure_converter, WithSerializationContext): failure_converter = failure_converter.with_context(context) + if isinstance(external_storage, WithSerializationContext): + external_storage = external_storage.with_context(context) if all( new is orig for new, orig in [ (payload_converter, self.payload_converter), (payload_codec, self.payload_codec), (failure_converter, self.failure_converter), + (external_storage, self.external_storage), ] ): return self @@ -1464,6 +1475,7 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "payload_converter", payload_converter) object.__setattr__(cloned, "payload_codec", payload_codec) object.__setattr__(cloned, "failure_converter", failure_converter) + object.__setattr__(cloned, "external_storage", external_storage) return cloned def _with_payload_error_limits( @@ -1523,12 +1535,16 @@ async def _encode_memo_existing( async def _encode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: + if self.external_storage is not None: + payload = await self.external_storage._store_payload(payload) if self.payload_codec: payload = (await self.payload_codec.encode([payload]))[0] self._validate_payload_limits([payload]) return payload async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + if self.external_storage is not None: + await 
self.external_storage._store_payloads(payloads) if self.payload_codec: await self.payload_codec.encode_wrapper(payloads) self._validate_payload_limits(payloads.payloads) @@ -1536,35 +1552,55 @@ async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - encoded_payloads = list(payloads) + result = ( + await self.external_storage._store_payload_sequence(payloads) + if self.external_storage is not None + else list(payloads) + ) if self.payload_codec: - encoded_payloads = await self.payload_codec.encode(encoded_payloads) - self._validate_payload_limits(encoded_payloads) - return encoded_payloads + result = await self.payload_codec.encode(result) + self._validate_payload_limits(result) + return result async def _decode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: if self.payload_codec: payload = (await self.payload_codec.decode([payload]))[0] + if self.external_storage is not None: + payload = await self.external_storage._retrieve_payload(payload) + elif len(payload.external_payloads) > 0: + warnings.warn( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", + category=extstore.StorageWarning, + ) return payload async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): if self.payload_codec: await self.payload_codec.decode_wrapper(payloads) + if self.external_storage is not None: + await self.external_storage._retrieve_payloads(payloads) + elif any(len(p.external_payloads) > 0 for p in payloads.payloads): + warnings.warn( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", + category=extstore.StorageWarning, + ) async def _decode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - if 
not self.payload_codec: - return list(payloads) - return await self.payload_codec.decode(payloads) - - # Temporary shortcircuit detection while the _decode_* methods may no-op if - # a payload codec is not configured. Remove once those paths have more to them. - @property - def _decode_payload_has_effect(self) -> bool: - return self.payload_codec is not None + result = list(payloads) + if self.payload_codec: + result = await self.payload_codec.decode(result) + if self.external_storage is not None: + result = await self.external_storage._retrieve_payload_sequence(result) + elif any(len(p.external_payloads) > 0 for p in result): + warnings.warn( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", + category=extstore.StorageWarning, + ) + return result @staticmethod async def _apply_to_failure_payloads( @@ -1642,6 +1678,27 @@ def _validate_limits( JSONPlainPayloadConverter(), # JSON Plain needs to remain in last because it throws on unknown types ) +# Imported here to break the circular dependency +from temporalio.converter import _extstore as extstore # noqa: E402 +from temporalio.converter._extstore import ( # noqa: E402 + ExternalStorage as ExternalStorage, +) +from temporalio.converter._extstore import ( # noqa: E402 + StorageDriver as StorageDriver, +) +from temporalio.converter._extstore import ( # noqa: E402 + StorageDriverClaim as StorageDriverClaim, +) +from temporalio.converter._extstore import ( # noqa: E402 + StorageDriverRetrieveContext as StorageDriverRetrieveContext, +) +from temporalio.converter._extstore import ( # noqa: E402 + StorageDriverStoreContext as StorageDriverStoreContext, +) +from temporalio.converter._extstore import ( # noqa: E402 + StorageWarning as StorageWarning, +) + DataConverter.default = DataConverter() PayloadConverter.default = DataConverter.default.payload_converter diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py new file mode 100644 index 
000000000..1c92211de --- /dev/null +++ b/temporalio/converter/_extstore.py @@ -0,0 +1,451 @@ +"""External payload storage support for offloading payloads to external storage systems.""" + +from __future__ import annotations + +import asyncio +import dataclasses +from abc import ABC, abstractmethod +from collections.abc import Callable, Coroutine, Mapping, Sequence +from dataclasses import dataclass +from typing import Any, ClassVar, TypeVar + +from typing_extensions import Self + +from temporalio.api.common.v1 import Payload, Payloads +from temporalio.converter import ( + JSONPlainPayloadConverter, + PayloadCodec, + SerializationContext, + WithSerializationContext, +) + +_T = TypeVar("_T") + + +async def _gather_cancel_on_error( + coros: Sequence[Coroutine[Any, Any, _T]], +) -> list[_T]: + """Run coroutines concurrently; cancel all remaining tasks if any one fails.""" + tasks = [asyncio.create_task(c) for c in coros] + try: + return await asyncio.gather(*tasks) + except BaseException: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + +@dataclass(frozen=True) +class StorageDriverClaim: + """Claim for an externally stored payload. + + .. warning:: + This API is experimental. + """ + + data: Mapping[str, str] + """Driver-defined data for identifying and retrieving an externally stored payload.""" + + +@dataclass(frozen=True) +class StorageDriverStoreContext: + """Context passed to :meth:`StorageDriver.store` and ``driver_selector`` calls. + + .. warning:: + This API is experimental. + """ + + serialization_context: SerializationContext | None = None + """The serialization context active when this store operation was initiated, + or ``None`` if no context has been set. + """ + + +@dataclass(frozen=True) +class StorageDriverRetrieveContext: + """Context passed to :meth:`StorageDriver.retrieve` calls. + + .. warning:: + This API is experimental. 
+ """ + + +class StorageDriver(ABC): + """Base driver for storing and retrieving payloads from external storage systems. + + .. warning:: + This API is experimental. + """ + + @abstractmethod + def name(self) -> str: + """Returns the name of this driver instance. A driver may allow its name + to be parameterized at construction time so that multiple instances of + the same driver class can coexist in :attr:`ExternalStorage.drivers` with + distinct names. + """ + raise NotImplementedError + + def type(self) -> str: + """Returns the type of the storage driver. This string should be the same + across all instantiations of the same driver class. This allows the equivalent + driver implementation in different languages to be named the same. + + Defaults to the class name. Subclasses may override this to return a + stable, language-agnostic identifier. + """ + return type(self).__name__ + + @abstractmethod + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + """Stores payloads in external storage and returns a :class:`StorageDriverClaim` + for each one. The returned list must be the same length as ``payloads``. + """ + raise NotImplementedError + + @abstractmethod + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + """Retrieves payloads from external storage for the given :class:`StorageDriverClaim` + list. The returned list must be the same length as ``claims``. + + Raise :class:`PayloadNotFoundError` when a retrieval attempt confirms + that a payload is absent from storage. This signals an unrecoverable + condition that will fail the workflow rather than retrying the workflow + task. + """ + raise NotImplementedError + + +class StorageWarning(RuntimeWarning): + """Warning for external storage issues. + + .. warning:: + This API is experimental. 
+ """ + + +@dataclass(frozen=True) +class _StorageReference: + driver_name: str + driver_claim: StorageDriverClaim + + +@dataclass(frozen=True) +class ExternalStorage(WithSerializationContext): + """Configuration for external storage behavior. + + .. warning:: + This API is experimental. + """ + + drivers: Sequence[StorageDriver] + """Drivers available for storing and retrieving payloads. At least one + driver must be provided. + + When no :attr:`driver_selector` is set, the first driver in this list is + used for all store operations. Additional drivers may be included solely to + support retrieval, for example to download payloads that remote callers + uploaded to an external storage system that is not your primary store + driver. Drivers in this list are looked up by :meth:`StorageDriver.name` during + retrieval, so each driver must have a unique name. + """ + + driver_selector: ( + Callable[[StorageDriverStoreContext, Payload], str | None] | None + ) = None + """Controls which driver stores a given payload. A callable of the form + ``(StorageDriverStoreContext, Payload) -> str | None`` that returns the name of + the driver to use, or ``None`` to leave the payload stored inline. + + When ``None``, the first driver in :attr:`drivers` is used for all store + operations. + """ + + payload_size_threshold: int | None = 256 * 1024 + """Minimum payload size in bytes before external storage is considered. + Defaults to 256 KiB. Set to ``None`` to consider every payload for + external storage regardless of size. + """ + + payload_codec: PayloadCodec | None = None + """Optional codec applied to payloads before they are handed to a + :class:`StorageDriver` for storage, and after they are retrieved. When ``None``, + payloads are stored as-is by the driver. 
+ """ + + _driver_map: dict[str, StorageDriver] = dataclasses.field( + init=False, repr=False, compare=False + ) + """Name-keyed index of :attr:`drivers`, built at construction time.""" + + _context: SerializationContext | None = dataclasses.field( + init=False, default=None, repr=False, compare=False + ) + + _claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ) + + def __post_init__(self) -> None: + """Validate drivers and build the internal name-keyed driver map. + + Raises :exc:`ValueError` if no drivers are provided or if any two drivers + share the same name. + """ + if not self.drivers: + raise ValueError( + "ExternalStorage.drivers must contain at least one driver." + ) + driver_map: dict[str, StorageDriver] = {} + for driver in self.drivers: + name = driver.name() + if name in driver_map: + raise ValueError( + f"ExternalStorage.drivers contains multiple drivers with name '{name}'. " + "Each driver must have a unique name." 
+ ) + driver_map[name] = driver + object.__setattr__(self, "_driver_map", driver_map) + + def with_context(self, context: SerializationContext) -> Self: + """Return a copy of these options with the serialization context applied.""" + payload_codec = self.payload_codec + if isinstance(payload_codec, WithSerializationContext): + payload_codec = payload_codec.with_context(context) + result = dataclasses.replace(self, payload_codec=payload_codec) + object.__setattr__(result, "_context", context) + return result + + def _select_driver( + self, context: StorageDriverStoreContext, payload: Payload + ) -> StorageDriver | None: + """Returns the driver to use for this payload, or None to pass through.""" + if ( + self.payload_size_threshold is not None + and payload.ByteSize() < self.payload_size_threshold + ): + return None + selector = self.driver_selector + if selector is None: + return self.drivers[0] if self.drivers else None + driver_name = selector(context, payload) + if driver_name is None: + return None + driver = self._driver_map.get(driver_name) + if driver is None: + raise ValueError(f"No driver found with name '{driver_name}'") + return driver + + def _get_driver_by_name(self, name: str) -> StorageDriver: + """Looks up a driver by name, raising :class:`ValueError` if not found.""" + driver = self._driver_map.get(name) + if driver is None: + raise ValueError(f"No driver found with name '{name}'") + return driver + + async def _store_payload(self, payload: Payload) -> Payload: + context = StorageDriverStoreContext(serialization_context=self._context) + + driver = self._select_driver(context, payload) + if driver is None: + return payload + + encoded_payload = payload + if self.payload_codec: + encoded_payload = (await self.payload_codec.encode([payload]))[0] + + claims = await driver.store(context, [encoded_payload]) + + self._validate_claim_length(claims, expected=1, driver=driver) + + reference = _StorageReference( + driver_name=driver.name(), + 
driver_claim=claims[0], + ) + reference_payload = self._claim_converter.to_payload(reference) + if reference_payload is None: + raise ValueError( + f"Failed to serialize storage reference for driver '{driver.name()}'" + ) + reference_payload.external_payloads.add().size_bytes = ( + encoded_payload.ByteSize() + ) + return reference_payload + + async def _store_payloads(self, payloads: Payloads): + stored_payloads = await self._store_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def _store_payload_sequence( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + if len(payloads) == 1: + return [await self._store_payload(payloads[0])] + + results = list(payloads) + context = StorageDriverStoreContext(serialization_context=self._context) + + to_store: list[tuple[int, Payload, StorageDriver]] = [] + for index, payload in enumerate(payloads): + driver = self._select_driver(context, payload) + if driver is None: + continue + to_store.append((index, payload, driver)) + + if not to_store: + return results + + payloads_to_encode = [payload for _, payload, _ in to_store] + encoded_payloads = payloads_to_encode + if self.payload_codec: + encoded_payloads = await self.payload_codec.encode(payloads_to_encode) + + driver_groups: dict[StorageDriver, list[tuple[int, Payload]]] = {} + for i, (orig_index, _, driver) in enumerate(to_store): + driver_groups.setdefault(driver, []).append( + (orig_index, encoded_payloads[i]) + ) + + driver_group_list = list(driver_groups.items()) + + all_claims = await _gather_cancel_on_error( + [ + driver.store(context, [p for _, p in indexed_payloads]) + for driver, indexed_payloads in driver_group_list + ] + ) + + for (driver, indexed_payloads), claims in zip(driver_group_list, all_claims): + indices = [idx for idx, _ in indexed_payloads] + sizes = [p.ByteSize() for _, p in indexed_payloads] + + self._validate_claim_length(claims, expected=len(indices), 
driver=driver) + + for i, claim in enumerate(claims): + reference = _StorageReference( + driver_name=driver.name(), + driver_claim=claim, + ) + reference_payload = self._claim_converter.to_payload(reference) + if reference_payload is None: + raise ValueError( + f"Failed to serialize storage reference for driver '{driver.name()}'" + ) + reference_payload.external_payloads.add().size_bytes = sizes[i] + results[indices[i]] = reference_payload + + return results + + async def _retrieve_payload(self, payload: Payload) -> Payload: + if len(payload.external_payloads) == 0: + return payload + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + return payload + + driver = self._get_driver_by_name(reference.driver_name) + context = StorageDriverRetrieveContext() + + stored_payloads = await driver.retrieve(context, [reference.driver_claim]) + + self._validate_payload_length(stored_payloads, expected=1, driver=driver) + + if self.payload_codec: + stored_payloads = await self.payload_codec.decode(stored_payloads) + + return stored_payloads[0] + + async def _retrieve_payloads(self, payloads: Payloads): + stored_payloads = await self._retrieve_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def _retrieve_payload_sequence( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + results = list(payloads) + + if len(payloads) == 1: + return [await self._retrieve_payload(payloads[0])] + + driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {} + for index, payload in enumerate(payloads): + if len(payload.external_payloads) == 0: + continue + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + continue + + driver = self._get_driver_by_name(reference.driver_name) + driver_claims.setdefault(driver, []).append((index, 
reference.driver_claim)) + + if not driver_claims: + return results + + context = StorageDriverRetrieveContext() + stored_by_index: dict[int, Payload] = {} + + driver_claim_list = list(driver_claims.items()) + + all_stored = await _gather_cancel_on_error( + [ + driver.retrieve(context, [claim for _, claim in indexed_claims]) + for driver, indexed_claims in driver_claim_list + ] + ) + + for (driver, indexed_claims), stored_payloads in zip( + driver_claim_list, all_stored + ): + indices = [idx for idx, _ in indexed_claims] + + self._validate_payload_length( + stored_payloads, + expected=len(indexed_claims), + driver=driver, + ) + + for idx, stored_payload in zip(indices, stored_payloads): + stored_by_index[idx] = stored_payload + + retrieve_indices = sorted(stored_by_index.keys()) + stored_list = [stored_by_index[idx] for idx in retrieve_indices] + + decoded_payloads = stored_list + if self.payload_codec: + decoded_payloads = await self.payload_codec.decode(stored_list) + + for i, retrieved_payload in enumerate(decoded_payloads): + results[retrieve_indices[i]] = retrieved_payload + + return results + + def _validate_claim_length( + self, claims: Sequence[StorageDriverClaim], expected: int, driver: StorageDriver + ) -> None: + if len(claims) != expected: + raise ValueError( + f"Driver '{driver.name()}' returned {len(claims)} claims, expected {expected}", + ) + + def _validate_payload_length( + self, payloads: Sequence[Payload], expected: int, driver: StorageDriver + ) -> None: + if len(payloads) != expected: + raise ValueError( + f"Driver '{driver.name()}' returned {len(payloads)} payloads, expected {expected}", + ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 4e6e06282..a895f54d2 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -625,7 +625,7 @@ async def _execute_activity( else None, ) - if self._encode_headers and data_converter._decode_payload_has_effect: + if self._encode_headers: for payload 
in start.header_fields.values(): payload.CopyFrom(await data_converter._decode_payload(payload)) diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 18f5599ba..6fff2e8b5 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -340,21 +340,44 @@ async def _handle_activation( "Failed handling activation on workflow with run ID %s", act.run_id ) - completion.failed.failure.SetInParent() - try: - data_converter.failure_converter.to_failure( - err, - data_converter.payload_converter, - completion.failed.failure, - ) - except Exception as inner_err: - logger.exception( - "Failed converting activation exception on workflow with run ID %s", - act.run_id, - ) - completion.failed.failure.message = ( - f"Failed converting activation exception: {inner_err}" - ) + if ( + isinstance(err, temporalio.exceptions.ApplicationError) + and err.non_retryable + ): + # Fail the workflow execution terminally rather than failing the task + command = completion.successful.commands.add() + failure = command.fail_workflow_execution.failure + failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) + else: + completion.failed.failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + completion.failed.failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + completion.failed.failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) completion.run_id = act.run_id diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py 
index 125f2a373..c2739417f 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1784,9 +1784,9 @@ def workflow_set_current_details(self, details: str): self._current_details = details def workflow_is_failure_exception(self, err: BaseException) -> bool: - # An exception is a failure instead of a task fail if it's already a - # failure error or if it is a timeout error or if it is an instance of - # any of the failure types in the worker or workflow-level setting + # An exception causes the workflow to fail (rather than the task) if it + # is already a failure error, a timeout error, or an instance of any of the + # failure exception types configured at the worker or workflow level. wf_failure_exception_types = self._defn.failure_exception_types if self._dynamic_failure_exception_types is not None: wf_failure_exception_types = self._dynamic_failure_exception_types diff --git a/tests/test_extstore.py b/tests/test_extstore.py new file mode 100644 index 000000000..38fa05018 --- /dev/null +++ b/tests/test_extstore.py @@ -0,0 +1,723 @@ +"""Tests for external storage functionality.""" + +import asyncio +from collections.abc import Sequence + +import pytest +from typing_extensions import Self + +from temporalio.api.common.v1 import Payload +from temporalio.converter import ( + ActivitySerializationContext, + DataConverter, + ExternalStorage, + JSONPlainPayloadConverter, + PayloadCodec, + SerializationContext, + StorageDriver, + StorageDriverClaim, + StorageDriverRetrieveContext, + StorageDriverStoreContext, + WithSerializationContext, + WorkflowSerializationContext, +) +from temporalio.converter._extstore import _StorageReference +from temporalio.exceptions import ApplicationError + + +class InMemoryTestDriver(StorageDriver): + """In-memory storage driver for testing.""" + + def __init__( + self, + driver_name: str = "test-driver", + ): + self._driver_name = driver_name + self._storage: dict[str, bytes] = {} + 
self._store_calls = 0 + self._retrieve_calls = 0 + + def name(self) -> str: + return self._driver_name + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + self._store_calls += 1 + start_index = len(self._storage) + + entries = [ + (f"payload-{start_index + i}", payload.SerializeToString()) + for i, payload in enumerate(payloads) + ] + self._storage.update(entries) + + return [StorageDriverClaim(data={"key": key}) for key, _ in entries] + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + self._retrieve_calls += 1 + + def parse_claim( + claim: StorageDriverClaim, + ) -> Payload: + key = claim.data["key"] + if key not in self._storage: + raise ApplicationError( + f"Payload not found for key '{key}'", non_retryable=True + ) + payload = Payload() + payload.ParseFromString(self._storage[key]) + return payload + + return [parse_claim(claim) for claim in claims] + + +class WorkflowIdFeatureFlagDriverSelector(WithSerializationContext): + """Example selector that conditionally stores based on workflow ID feature flag. + + This example shows how a callable can implement WithSerializationContext if it + needs to precompute data from the serialization context instead of doing it on + every payload selection call. + + The feature flag in this example is a simple check on the workflow ID length, but in + a real implementation this could be a call to a feature flag service or a lookup in a + configuration store. 
+ """ + + def __init__(self, driver: StorageDriver, enabled: bool = False): + self._driver = driver + self._enabled = enabled + + def __call__( + self, _context: StorageDriverStoreContext, _payload: Payload + ) -> StorageDriver | None: + return self._driver if self._enabled else None + + def with_context(self, context: SerializationContext) -> Self: + workflow_id = None + if isinstance(context, ActivitySerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + if isinstance(context, WorkflowSerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + + # Create new instance with updated enabled flag and propagate context to inner driver + driver = self._driver + if isinstance(driver, WithSerializationContext): + driver = driver.with_context(context) + + return type(self)( + driver, WorkflowIdFeatureFlagDriverSelector.feature_flag_is_on(workflow_id) + ) + + @staticmethod + def feature_flag_is_on(workflow_id: str | None) -> bool: + """Mock implementation of a feature flag based on a workflow ID.""" + return workflow_id is not None and len(workflow_id) % 2 == 0 + + +class TestDataConverterExternalStorage: + """Tests for DataConverter with external storage.""" + + async def test_extstore_encode_decode(self): + """Test that large payloads are stored externally.""" + driver = InMemoryTestDriver() + + # Configure with 100-byte threshold + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=100, + ) + ) + + # Small value should not be externalized + small_value = "small" + encoded_small = await converter.encode([small_value]) + assert len(encoded_small) == 1 + assert not encoded_small[0].external_payloads # Not externalized + assert driver._store_calls == 0 + + # Large value should be externalized + large_value = "x" * 200 + encoded_large = await converter.encode([large_value]) + assert len(encoded_large) == 1 + assert len(encoded_large[0].external_payloads) > 0 # 
Externalized + assert driver._store_calls == 1 + + # Decode large value + decoded = await converter.decode(encoded_large, [str]) + assert len(decoded) == 1 + assert decoded[0] == large_value + assert driver._retrieve_calls == 1 + + async def test_extstore_reference_structure(self): + """Test that external storage creates proper reference structure.""" + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[InMemoryTestDriver("test-driver")], + payload_size_threshold=50, + ) + ) + + # Create large payload + large_value = "x" * 100 + encoded = await converter.encode([large_value]) + + # Verify reference structure + reference_payload = encoded[0] + assert len(reference_payload.external_payloads) > 0 + + # The payload should contain a serialized _StorageReference + # Deserialize it to verify structure using the same encoding + claim_converter = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ) + reference = claim_converter.from_payload(reference_payload, _StorageReference) + + assert isinstance(reference, _StorageReference) + assert "test-driver" == reference.driver_name + assert isinstance(reference.driver_claim, StorageDriverClaim) + assert "key" in reference.driver_claim.data + + async def test_extstore_composite_conditional(self): + """Test using multiple drivers based on size.""" + hot_driver = InMemoryTestDriver("hot-storage") + cold_driver = InMemoryTestDriver("cold-storage") + + options = ExternalStorage( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: "hot-storage" + if payload.ByteSize() < 500 + else "cold-storage", + payload_size_threshold=100, + ) + converter = DataConverter(external_storage=options) + + # Small payload (not externalized) + small = "x" * 50 + encoded_small = await converter.encode([small]) + assert not encoded_small[0].external_payloads + assert hot_driver._store_calls == 0 + assert cold_driver._store_calls == 0 + + # Medium payload (hot storage) + medium
= "x" * 200 + encoded_medium = await converter.encode([medium]) + assert len(encoded_medium[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 + assert cold_driver._store_calls == 0 + + # Large payload (cold storage) + large = "x" * 2000 + encoded_large = await converter.encode([large]) + assert len(encoded_large[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 # Unchanged + assert cold_driver._store_calls == 1 + + # Verify retrieval from correct drivers + decoded_medium = await converter.decode(encoded_medium, [str]) + assert decoded_medium[0] == medium + assert hot_driver._retrieve_calls == 1 + + decoded_large = await converter.decode(encoded_large, [str]) + assert decoded_large[0] == large + assert cold_driver._retrieve_calls == 1 + + +class NotFoundDriver(StorageDriver): + """Driver that stores normally but raises non-retryable ApplicationError on retrieve.""" + + def __init__(self, driver_name: str = "not-found-driver"): + self._driver_name = driver_name + self._storage: dict[str, bytes] = {} + + def name(self) -> str: + return self._driver_name + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + entries = [ + (f"payload-{i}", payload.SerializeToString()) + for i, payload in enumerate(payloads) + ] + self._storage.update(entries) + return [StorageDriverClaim(data={"key": key}) for key, _ in entries] + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + assert len(claims) > 0, "NotFoundDriver expected claims to be provided" + raise ApplicationError("Payload not found.", non_retryable=True) + + +class TestDriverError: + """Tests for ValueError raised when a driver violates its contract.""" + + async def test_encode_wrong_claim_count_raises_value_error(self): + """store() returning fewer claims than payloads must raise ValueError.""" + + class
_NoClaimsDriver(InMemoryTestDriver): + async def store( + self, context: StorageDriverStoreContext, payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: + return [] + + driver = _NoClaimsDriver() + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=10, + ) + ) + with pytest.raises( + ValueError, + match=f"Driver '{driver.name()}' returned 0 claims, expected 1", + ): + await converter.encode(["x" * 200]) + + async def test_decode_wrong_payload_count_raises_value_error(self): + """retrieve() returning fewer payloads than claims must raise ValueError.""" + good_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[InMemoryTestDriver()], + payload_size_threshold=10, + ) + ) + encoded = await good_converter.encode(["x" * 200]) + + class _NoPayloadsDriver(InMemoryTestDriver): + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + return [] + + driver = _NoPayloadsDriver() + bad_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=10, + ) + ) + with pytest.raises( + ValueError, + match=f"Driver '{driver.name()}' returned 0 payloads, expected 1", + ): + await bad_converter.decode(encoded, [str]) + + async def test_store_cancels_in_flight_driver_on_error(self): + """When one driver raises during concurrent store, other in-flight drivers are cancelled.""" + store_cancelled = asyncio.Event() + + class _SleepingStoreDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("sleeping") + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + try: + await asyncio.sleep(float("inf")) + except asyncio.CancelledError: + store_cancelled.set() + raise + return [] # unreachable + + class _FailingStoreDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("failing") +
+ async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + raise RuntimeError("driver store failure") + + drivers = [_SleepingStoreDriver(), _FailingStoreDriver()] + names = iter(["sleeping", "failing"]) + converter = DataConverter( + external_storage=ExternalStorage( + drivers=drivers, + driver_selector=lambda ctx, p: next(names), + payload_size_threshold=None, + ) + ) + + with pytest.raises(RuntimeError, match="driver store failure"): + await converter.encode(["payload_a", "payload_b"]) + + assert store_cancelled.is_set() + + async def test_retrieve_cancels_in_flight_driver_on_error(self): + """When one driver raises during concurrent retrieve, other in-flight drivers are cancelled.""" + retrieve_cancelled = asyncio.Event() + + class _SleepingRetrieveDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("sleeping") + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + try: + await asyncio.sleep(float("inf")) + except asyncio.CancelledError: + retrieve_cancelled.set() + raise + return [] # unreachable + + class _FailingRetrieveDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("failing") + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + raise RuntimeError("driver retrieve failure") + + drivers: list[StorageDriver] = [ + _SleepingRetrieveDriver(), + _FailingRetrieveDriver(), + ] + names = iter(["sleeping", "failing"]) + converter = DataConverter( + external_storage=ExternalStorage( + drivers=drivers, + driver_selector=lambda ctx, p: next(names), + payload_size_threshold=None, + ) + ) + encoded = await converter.encode(["payload_a", "payload_b"]) + + with pytest.raises(RuntimeError, match="driver retrieve failure"): + await converter.decode(encoded, [str, str]) + + assert 
retrieve_cancelled.is_set() + + +class RecordingPayloadCodec(PayloadCodec): + """Codec that wraps each payload under a recognisable ``encoding`` label. + + Encode sets ``metadata["encoding"]`` to ``encoding_label`` and stores the + serialised inner payload as ``data``. Decode reverses that. The call + counters let tests assert exactly how many payloads each codec processed. + """ + + def __init__(self, encoding_label: str) -> None: + self._encoding_label = encoding_label.encode() + self.encoded_count = 0 + self.decoded_count = 0 + + async def encode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.encoded_count += len(payloads) + results = [] + for p in payloads: + wrapped = Payload() + wrapped.metadata["encoding"] = self._encoding_label + wrapped.data = p.SerializeToString() + results.append(wrapped) + return results + + async def decode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.decoded_count += len(payloads) + results = [] + for p in payloads: + inner = Payload() + inner.ParseFromString(p.data) + results.append(inner) + return results + + +class TestPayloadCodecWithExternalStorage: + """Tests for interaction between DataConverter.payload_codec and external storage.""" + + async def test_dc_payload_codec_encodes_reference_payload(self): + """DataConverter.payload_codec encodes the reference payload in workflow + history but does NOT encode the bytes handed to the driver for storage.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=50, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # The reference payload written to history must carry the dc_codec label. 
+ assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # The bytes given to the driver must NOT carry the dc_codec label. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") != b"binary/dc-encoded" + assert stored_payload.metadata.get("encoding") == b"json/plain" + + # Round-trip must recover the original value. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + async def test_external_converter_without_codec_does_not_encode_stored_bytes(self): + """When DataConverter.payload_codec is set but ExternalStorage.payload_codec + is None, stored bytes are NOT encoded – even though + DataConverter.payload_codec is active for the reference payload in history.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=50, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # Reference payload in history is still encoded by DataConverter.payload_codec. + assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # Stored bytes are NOT encoded. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") != b"binary/dc-encoded" + assert stored_payload.metadata.get("encoding") == b"json/plain" + + # Round-trip. 
+ decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + async def test_external_converter_codec_independent_from_dc_codec(self): + """When both DataConverter.payload_codec and ExternalStorage.payload_codec + are set, the reference payload in history uses DataConverter.payload_codec + and the bytes stored by the driver use ExternalStorage.payload_codec – + independently.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + ext_codec = RecordingPayloadCodec("binary/ext-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=50, + payload_codec=ext_codec, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # Each codec was applied exactly once during encode. + assert dc_codec.encoded_count == 1 + assert ext_codec.encoded_count == 1 + + # Reference payload carries dc_codec's label. + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # Stored bytes carry ext_codec's label – different from the reference. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") == b"binary/ext-encoded" + assert stored_payload.metadata.get("encoding") != encoded[0].metadata.get( + "encoding" + ) + + # Round-trip must recover the original value using both codecs. 
+ decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert ext_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + +class TestMultiDriver: + """Tests for ExternalStorage with multiple drivers.""" + + async def test_no_selector_uses_first_driver_for_store(self): + """Without a driver_selector the first driver in the list handles all + store operations. Additional drivers are never called for store.""" + first = InMemoryTestDriver("driver-first") + second = InMemoryTestDriver("driver-second") + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[first, second], + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert first._store_calls == 1 + assert second._store_calls == 0 + + # The reference in history names the first driver. + ref = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ).from_payload(encoded[0], _StorageReference) + assert ref.driver_name == "driver-first" + + # Retrieval also goes to the first driver. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert first._retrieve_calls == 1 + assert second._retrieve_calls == 0 + + async def test_no_selector_second_driver_is_retrieve_only(self): + """A driver that is second in the list acts as a retrieve-only driver. + References are resolved by name, not by position, so a payload stored + by driver-b is retrieved correctly even when driver-a is listed first.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Store with driver-b as the sole driver. + store_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver_b], + payload_size_threshold=50, + ) + ) + large = "y" * 200 + encoded = await store_converter.encode([large]) + + # Retrieve with driver-a listed first, driver-b second. 
+ # The "driver-b" name in the reference must route to driver-b. + retrieve_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver_a, driver_b], + payload_size_threshold=50, + ) + ) + decoded = await retrieve_converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver_a._retrieve_calls == 0 # never consulted + assert driver_b._retrieve_calls == 1 + + async def test_selector_routes_payloads_to_different_drivers_in_single_batch(self): + """When a selector routes different payloads to different drivers, a + single encode([v1, v2, ...]) call batches payloads per driver so each + driver receives exactly one store() call regardless of how many + payloads are routed to it.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Route payloads that serialise to < 500 bytes to driver_a, larger ones + # to driver_b. + def selector(_ctx: object, payload: Payload) -> str: + return "driver-a" if payload.ByteSize() < 500 else "driver-b" + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver_a, driver_b], + driver_selector=selector, + payload_size_threshold=50, + ) + ) + + small_ext = "a" * 100 # above threshold, serialises well below 500 B + large_ext = "b" * 1000 # serialises above 500 B + + # Encode both values in a single call — they should be batched per driver. + encoded = await converter.encode([small_ext, large_ext]) + assert driver_a._store_calls == 1 # one batched call, not two individual ones + assert driver_b._store_calls == 1 + + # Full round-trip. 
+ decoded = await converter.decode(encoded, [str, str]) + assert decoded == [small_ext, large_ext] + assert driver_a._retrieve_calls == 1 + assert driver_b._retrieve_calls == 1 + + async def test_selector_returning_none_keeps_payload_inline(self): + """A selector that returns None for a payload leaves it stored inline + in workflow history rather than offloading it to any driver, even when + the payload exceeds the size threshold.""" + driver = InMemoryTestDriver("driver-a") + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + driver_selector=lambda _ctx, _payload: None, + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert driver._store_calls == 0 + assert len(encoded[0].external_payloads) == 0 # payload is inline + + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver._retrieve_calls == 0 + + async def test_selector_returns_unregistered_driver_raises(self): + """A selector that returns a Driver whose name is not present in + ExternalStorage.drivers raises ValueError during encode.""" + registered = InMemoryTestDriver("registered") + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[registered], + driver_selector=lambda _ctx, _payload: "not-in-list", + payload_size_threshold=50, + ) + ) + + with pytest.raises(ValueError): + await converter.encode(["x" * 200]) + + def test_duplicate_driver_names_raises(self): + """Registering two drivers with identical names raises ValueError immediately + when constructing ExternalStorage.""" + first = InMemoryTestDriver("dup-name") + duplicate = InMemoryTestDriver("dup-name") + + with pytest.raises(ValueError, match="dup-name"): + ExternalStorage( + drivers=[first, duplicate], + payload_size_threshold=50, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py new file mode 100644 
index 000000000..f4e038ed1 --- /dev/null +++ b/tests/worker/test_extstore.py @@ -0,0 +1,428 @@ +import dataclasses +import uuid +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import timedelta + +import pytest + +import temporalio +import temporalio.converter +from temporalio import activity, workflow +from temporalio.api.common.v1 import Payload +from temporalio.client import Client, WorkflowFailureError, WorkflowHandle +from temporalio.common import RetryPolicy +from temporalio.converter import ( + ExternalStorage, + StorageDriverClaim, + StorageDriverRetrieveContext, + StorageDriverStoreContext, + StorageWarning, +) +from temporalio.exceptions import ActivityError, ApplicationError +from temporalio.testing._workflow import WorkflowEnvironment +from temporalio.worker import Replayer +from tests.helpers import assert_task_fail_eventually, new_worker +from tests.test_extstore import InMemoryTestDriver + + +@dataclass(frozen=True) +class ExtStoreActivityInput: + input_data: str + output_size: int + pass + + +@activity.defn +async def ext_store_activity( + input: ExtStoreActivityInput, +) -> str: + return "ao" * int(input.output_size / 2) + + +@dataclass(frozen=True) +class ExtStoreWorkflowInput: + input_data: str + activity_input_size: int + activity_output_size: int + output_size: int + max_activity_attempts: int | None = None + + +@workflow.defn +class ExtStoreWorkflow: + @workflow.run + async def run(self, input: ExtStoreWorkflowInput) -> str: + retry_policy = ( + RetryPolicy(maximum_attempts=input.max_activity_attempts) + if input.max_activity_attempts is not None + else None + ) + await workflow.execute_activity( + ext_store_activity, + ExtStoreActivityInput( + input_data="ai" * int(input.activity_input_size / 2), + output_size=input.activity_output_size, + ), + schedule_to_close_timeout=timedelta(seconds=3), + retry_policy=retry_policy, + ) + return "wo" * int(input.output_size / 2) + + +class 
BadTestDriver(InMemoryTestDriver): + def __init__( + self, + driver_name: str = "bad-driver", + no_store: bool = False, + no_retrieve: bool = False, + raise_payload_not_found: bool = False, + ): + super().__init__(driver_name) + self._no_store = no_store + self._no_retrieve = no_retrieve + self._raise_payload_not_found = raise_payload_not_found + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + if self._no_store: + return [] + return await super().store(context, payloads) + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + if self._no_retrieve: + return [] + if self._raise_payload_not_found: + raise ApplicationError( + "Payload not found", + type="PayloadNotFoundError", + non_retryable=True, + ) + return await super().retrieve(context, claims) + + +async def test_extstore_activity_input_no_retrieve( + env: WorkflowEnvironment, +): + """Using a small threshold, validate that an activity input over the + threshold fails the activity when the driver retrieves no payloads.""" + driver = BadTestDriver(no_retrieve=True) + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="workflow input", + activity_input_size=1000, + activity_output_size=10, + output_size=10, + max_activity_attempts=1, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + with pytest.raises(WorkflowFailureError) as err: + await handle.result() + + assert isinstance(err.value.cause, ActivityError) + + +async
def test_extstore_activity_result_no_store( + env: WorkflowEnvironment, +): + """Using a small threshold, validate that activity result size over + the threshold causes driver to be invoked.""" + driver = BadTestDriver(no_store=True) + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="workflow input", + activity_input_size=10, + activity_output_size=1000, + output_size=10, + max_activity_attempts=1, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + with pytest.raises(WorkflowFailureError) as err: + await handle.result() + + assert isinstance(err.value.cause, ActivityError) + + +async def test_extstore_worker_missing_driver( + env: WorkflowEnvironment, +): + """Validate that when a worker is provided a workflow history with + external storage references and the worker is not configured for external + storage, it will cause a workflow task failure. 
+ """ + driver = InMemoryTestDriver() + + far_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + worker_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + ) + + async with new_worker( + worker_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await far_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 1024, + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + await assert_task_fail_eventually(handle) + + +async def test_extstore_payload_not_found_fails_workflow( + env: WorkflowEnvironment, +): + """When a non-retryable ApplicationError is raised while retrieving workflow input, + the workflow must fail terminally (not retry as a task failure). 
+ """ + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[BadTestDriver(raise_payload_not_found=True)], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 512, # exceeds 1024-byte threshold + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + with pytest.raises(WorkflowFailureError) as exc_info: + await handle.result() + + assert isinstance(exc_info.value.cause, ApplicationError) + assert exc_info.value.cause.type == "PayloadNotFoundError" + assert exc_info.value.cause.non_retryable is True + + +async def _run_extstore_workflow_and_fetch_history( + env: WorkflowEnvironment, + driver: InMemoryTestDriver, + *, + input_data: str, + activity_output_size: int = 10, +) -> WorkflowHandle: + """Helper: run ExtStoreWorkflow with the given driver and return its history handle.""" + extstore_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=512, + ), + ), + ) + async with new_worker( + extstore_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await extstore_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data=input_data, + activity_input_size=10, + activity_output_size=activity_output_size, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + 
await handle.result() + return handle + + +async def test_replay_extstore_history_fails_without_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input fails to replay when the + Replayer has no external storage configured.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, + driver, + input_data="wi" * 512, # exceeds 512-byte threshold + ) + history = await handle.fetch_history() + + # Replay without external storage — the reference payload cannot be decoded. + # The middleware emits a StorageWarning when it encounters a reference payload + # with no driver configured. + with pytest.warns(StorageWarning, match="TMPRL1105"): + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) + # Must be a task-failure RuntimeError, not a NondeterminismError — external + # storage decode failures are distinct from workflow code changes. + assert isinstance(result.replay_failure, RuntimeError) + assert not isinstance(result.replay_failure, workflow.NondeterminismError) + # The message is the full activation-completion failure string; the + # "Failed decoding arguments" text from _convert_payloads is embedded in it. + assert "Failed decoding arguments" in result.replay_failure.args[0] + + +async def test_replay_extstore_history_succeeds_with_correct_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input replays successfully when the + Replayer is configured with the same storage driver that holds the data.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, driver, input_data="wi" * 512 + ) + history = await handle.fetch_history() + + # Replay with the same populated driver — must succeed. 
+    await Replayer(
+        workflows=[ExtStoreWorkflow],
+        data_converter=dataclasses.replace(
+            temporalio.converter.default(),
+            external_storage=ExternalStorage(
+                drivers=[driver],
+                payload_size_threshold=512,
+            ),
+        ),
+    ).replay_workflow(history)
+
+
+async def test_replay_extstore_history_fails_with_empty_driver(
+    env: WorkflowEnvironment,
+) -> None:
+    """A history with external storage references fails to replay when the
+    Replayer has external storage configured but the driver holds no data
+    (simulates pointing at the wrong backend or a purged store)."""
+    driver = InMemoryTestDriver()
+    handle = await _run_extstore_workflow_and_fetch_history(
+        env, driver, input_data="wi" * 512
+    )
+    history = await handle.fetch_history()
+
+    # Replay with a fresh empty driver — retrieval will fail.
+    result = await Replayer(
+        workflows=[ExtStoreWorkflow],
+        data_converter=dataclasses.replace(
+            temporalio.converter.default(),
+            external_storage=ExternalStorage(
+                drivers=[InMemoryTestDriver()],
+                payload_size_threshold=512,
+            ),
+        ),
+    ).replay_workflow(history, raise_on_replay_failure=False)
+    # InMemoryTestDriver raises ApplicationError for absent keys.
+    # ApplicationError is re-raised without wrapping, so it propagates
+    # through decode_activation (before the workflow task runs). The core SDK
+    # receives an activation failure, issues a FailWorkflow command, but the
+    # next history event is ActivityTaskScheduled — causing a NondeterminismError.
+    assert isinstance(result.replay_failure, workflow.NondeterminismError)
+
+
+async def test_replay_extstore_activity_result_fails_without_extstore(
+    env: WorkflowEnvironment,
+) -> None:
+    """A history where only the activity result was stored externally (the
+    workflow input is small enough to be inline) also fails to replay without
+    external storage — verifying that mid-workflow decode failures are caught."""
+    driver = InMemoryTestDriver()
+    handle = await _run_extstore_workflow_and_fetch_history(
+        env,
+        driver,
+        input_data="small",  # well under 512 bytes — stays inline
+        activity_output_size=2048,  # 2 KB result — stored externally
+    )
+    history = await handle.fetch_history()
+
+    # Replay without external storage. The workflow input decodes fine, but
+    # when the ActivityTaskCompleted result is delivered back to the workflow
+    # coroutine it cannot be decoded.
+    with pytest.warns(StorageWarning, match="TMPRL1105"):
+        result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow(
+            history, raise_on_replay_failure=False
+        )
+    # Mid-workflow decode failure is still a task failure (RuntimeError), not
+    # nondeterminism.
+    assert isinstance(result.replay_failure, RuntimeError)
+    assert not isinstance(result.replay_failure, workflow.NondeterminismError)
+    # The message is the full activation-completion failure string; the
+    # "Failed decoding arguments" text from _convert_payloads is embedded in it.
+    assert "Failed decoding arguments" in result.replay_failure.args[0]