diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index 187cd853cb27..2c2b9817d9ff 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -3,6 +3,7 @@
 ### 4.16.0b3 (Unreleased)
 
 #### Features Added
+* Added the `EmbeddingProvider` Protocol and `EmbeddingResult` dataclass defining the contract the SDK will use to generate vector embeddings for `GenerateEmbeddings(...)` query expressions. See [PR 46902](https://github.com/Azure/azure-sdk-for-python/pull/46902)
 
 #### Breaking Changes
 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py b/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py
index d7501df99558..7caf5c607aba 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/__init__.py
@@ -21,6 +21,8 @@
 
 from ._version import VERSION
 from ._cosmos_responses import CosmosDict, CosmosList
+from ._embedding_provider import EmbeddingProvider
+from ._embedding_result import EmbeddingResult
 from ._retry_utility import ConnectionRetryPolicy
 from .container import ContainerProxy
 from .cosmos_client import CosmosClient
@@ -66,6 +68,8 @@
     "ConnectionRetryPolicy",
     "ThroughputProperties",
     "CosmosDict",
-    "CosmosList"
+    "CosmosList",
+    "EmbeddingProvider",
+    "EmbeddingResult"
 )
 __version__ = VERSION
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_embedding_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_embedding_provider.py
new file mode 100644
index 000000000000..d55356f90018
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_embedding_provider.py
@@ -0,0 +1,46 @@
+# The MIT License (MIT)
+# Copyright (c) Microsoft Corporation. All rights reserved.
+
+from typing import Any, Protocol, Sequence, runtime_checkable
+
+from ._embedding_result import EmbeddingResult
+
+
+@runtime_checkable
+class EmbeddingProvider(Protocol):
+    """Protocol for classes that generate text embeddings for Azure Cosmos DB queries.
+
+    Implementations are invoked by the SDK to embed literal text in queries
+    that use ``GenerateEmbeddings(...)``. A provider may be attached at the
+    client level or overridden at the container level. Implementations must be
+    safe to call concurrently.
+
+    Because the protocol is decorated with ``runtime_checkable``, it can be used
+    with ``isinstance``; such a check only verifies that a ``generate_embeddings``
+    attribute exists, not its signature or return type.
+    """
+
+    def generate_embeddings(
+        self,
+        texts: Sequence[str],
+        *,
+        endpoint: str,
+        deployment_name: str,
+        dimensions: int,
+        **kwargs: Any,
+    ) -> EmbeddingResult:
+        """Generate one embedding vector per input string.
+
+        The signature allows additional keyword arguments (``**kwargs``);
+        implementations should accept them.
+
+        :param texts: The input strings to embed. The returned vectors must be
+            in the same order as the inputs.
+        :type texts: Sequence[str]
+        :keyword str endpoint: The embedding service endpoint.
+        :keyword str deployment_name: The model deployment name.
+        :keyword int dimensions: The expected vector dimensionality.
+        :returns: An :class:`EmbeddingResult` containing the generated vectors.
+        :rtype: ~azure.cosmos.EmbeddingResult
+        """
+        ...
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_embedding_result.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_embedding_result.py
new file mode 100644
index 000000000000..4b841053b281
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_embedding_result.py
@@ -0,0 +1,21 @@
+# The MIT License (MIT)
+# Copyright (c) Microsoft Corporation. All rights reserved.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class EmbeddingResult:
+    """Represents the result of an embedding generation call.
+
+    :ivar vectors: The generated embedding vectors, one per input string,
+        in the same order as the inputs.
+    :vartype vectors: List[List[float]]
+    :ivar total_tokens: The total number of tokens consumed by the embedding call.
+        Defaults to ``None`` when not supplied.
+    :vartype total_tokens: Optional[int]
+    """
+
+    vectors: List[List[float]]
+    total_tokens: Optional[int] = None
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/__init__.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/__init__.py
index 7d8ac99702e1..c6aaf1a28e71 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/__init__.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/__init__.py
@@ -22,6 +22,7 @@
 from ._container import ContainerProxy
 from ._cosmos_client import CosmosClient
 from ._database import DatabaseProxy
+from ._embedding_provider import EmbeddingProvider
 from ._user import UserProxy
 from ._scripts import ScriptsProxy
 
@@ -30,5 +31,6 @@
     "DatabaseProxy",
     "ContainerProxy",
     "ScriptsProxy",
-    "UserProxy"
+    "UserProxy",
+    "EmbeddingProvider"
 )
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_embedding_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_embedding_provider.py
new file mode 100644
index 000000000000..697d994fc643
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_embedding_provider.py
@@ -0,0 +1,45 @@
+# The MIT License (MIT)
+# Copyright (c) Microsoft Corporation. All rights reserved.
+
+from typing import Any, Protocol, Sequence, runtime_checkable
+
+from .._embedding_result import EmbeddingResult
+
+
+@runtime_checkable
+class EmbeddingProvider(Protocol):
+    """Asynchronous Protocol for classes that generate text embeddings for Azure Cosmos DB queries.
+
+    Asynchronous counterpart of :class:`azure.cosmos.EmbeddingProvider` for use
+    with :class:`azure.cosmos.aio.CosmosClient`.
+
+    Because the protocol is decorated with ``runtime_checkable``, it can be used
+    with ``isinstance``; such a check only verifies that a ``generate_embeddings``
+    attribute exists. It does not verify that the method is a coroutine function,
+    so an object implementing the synchronous protocol also passes the check.
+    """
+
+    async def generate_embeddings(
+        self,
+        texts: Sequence[str],
+        *,
+        endpoint: str,
+        deployment_name: str,
+        dimensions: int,
+        **kwargs: Any,
+    ) -> EmbeddingResult:
+        """Asynchronously generate one embedding vector per input string.
+
+        The signature allows additional keyword arguments (``**kwargs``);
+        implementations should accept them.
+
+        :param texts: The input strings to embed. The returned vectors must be
+            in the same order as the inputs.
+        :type texts: Sequence[str]
+        :keyword str endpoint: The embedding service endpoint.
+        :keyword str deployment_name: The model deployment name.
+        :keyword int dimensions: The expected vector dimensionality.
+        :returns: An :class:`EmbeddingResult` containing the generated vectors.
+        :rtype: ~azure.cosmos.EmbeddingResult
+        """
+        ...