From c0137077677d1a7e447596af7471f7da68493fb1 Mon Sep 17 00:00:00 2001 From: Christian Leopoldseder Date: Tue, 14 Apr 2026 07:21:04 -0700 Subject: [PATCH] feat: GenAI SDK client(multimodal) - Allow passing dataset ID in addition to full resource name in dataset methods. PiperOrigin-RevId: 899573600 --- .../replays/test_get_multimodal_datasets.py | 19 +++ vertexai/_genai/_datasets_utils.py | 8 ++ vertexai/_genai/datasets.py | 114 +++++++++++++----- 3 files changed, 112 insertions(+), 29 deletions(-) diff --git a/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py index dbc9da776e..09769040e9 100644 --- a/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_get_multimodal_datasets.py @@ -41,6 +41,15 @@ def test_get_dataset_from_public_method(client): assert dataset.display_name == "test-display-name" +def test_get_dataset_by_id(client): + dataset = client.datasets.get_multimodal_dataset( + name="8810841321427173376", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), @@ -67,3 +76,13 @@ async def test_get_dataset_from_public_method_async(client): assert isinstance(dataset, types.MultimodalDataset) assert dataset.name == DATASET assert dataset.display_name == "test-display-name" + + +@pytest.mark.asyncio +async def test_get_dataset_by_id_async(client): + dataset = await client.aio.datasets.get_multimodal_dataset( + name="8810841321427173376", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py index a913523e7e..f853b4b2c1 100644 --- a/vertexai/_genai/_datasets_utils.py +++ 
b/vertexai/_genai/_datasets_utils.py
@@ -262,3 +262,25 @@ async def save_dataframe_to_bigquery_async(
     )
     await asyncio.to_thread(copy_job.result)
     await asyncio.to_thread(bq_client.delete_table, temp_table_id)
+
+
+def resolve_dataset_name(resource_name_or_id: str, project: str, location: str) -> str:
+    """Resolves a dataset name or ID to a full resource name.
+
+    Args:
+        resource_name_or_id: A bare dataset ID (e.g. "123") or a dataset
+            resource name starting with "projects/".
+        project: Project used to build the resource name for a bare ID.
+        location: Location used to build the resource name for a bare ID.
+
+    Returns:
+        The input unchanged when it already starts with "projects/",
+        otherwise "projects/{project}/locations/{location}/datasets/{id}".
+    """
+    # Pass through anything that is already a resource name. Matching only
+    # the client's own project/location prefix would re-prefix names that
+    # spell the project differently (number vs. ID) or that point at another
+    # project or location, yielding a malformed "projects/.../projects/..." path.
+    if resource_name_or_id.startswith("projects/"):
+        return resource_name_or_id
+    return f"projects/{project}/locations/{location}/datasets/{resource_name_or_id}"
diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py
index e9febf01b0..2b10805729 100644
--- a/vertexai/_genai/datasets.py
+++ b/vertexai/_genai/datasets.py
@@ -1130,8 +1130,8 @@ def get_multimodal_dataset(
 
         Args:
           name:
-            Required. name of a multimodal dataset. The name should be in
-            the format of "projects/{project}/locations/{location}/datasets/{dataset}".
+            Required. A fully-qualified resource name or ID of the dataset.
+            Example: "projects/.../locations/.../datasets/123" or "123".
           config:
             Optional. A configuration for getting the multimodal dataset. If not
             provided, the default configuration will be used.
@@ -1145,6 +1145,10 @@ def get_multimodal_dataset(
         elif not config:
             config = types.VertexBaseConfig()
 
+        name = _datasets_utils.resolve_dataset_name(
+            name, self._api_client.project, self._api_client.location
+        )
+
         return self._get_multimodal_dataset(config=config, name=name)
 
     def delete_multimodal_dataset(
@@ -1157,8 +1161,8 @@ def delete_multimodal_dataset(
 
         Args:
           name:
-            Required. name of a multimodal dataset. The name should be in
-            the format of "projects/{project}/locations/{location}/datasets/{dataset}".
+            Required. A fully-qualified resource name or ID of the dataset.
+            Example: "projects/.../locations/.../datasets/123" or "123".
           config:
             Optional. A configuration for deleting the multimodal dataset. If not
             provided, the default configuration will be used.
@@ -1172,6 +1176,10 @@ def delete_multimodal_dataset( elif not config: config = types.VertexBaseConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + return self._delete_multimodal_dataset(config=config, name=name) def assemble( @@ -1189,8 +1197,8 @@ def assemble( Args: name: - Required. The name of the dataset to assemble. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". gemini_request_read_config: Optional. The read config to use to assemble the dataset. If not provided, the read config attached to the dataset will be @@ -1207,6 +1215,10 @@ def assemble( elif not config: config = types.AssembleDatasetConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + operation = self._assemble_multimodal_dataset( name=name, gemini_request_read_config=gemini_request_read_config, @@ -1232,8 +1244,8 @@ def assess_tuning_resources( Args: dataset_name: - Required. The name of the dataset to assess the tuning resources - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning resources for. 
@@ -1255,6 +1267,10 @@ def assess_tuning_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig( @@ -1288,8 +1304,8 @@ def assess_tuning_validity( Args: dataset_name: - Required. The name of the dataset to assess the tuning validity - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning validity for. @@ -1316,6 +1332,10 @@ def assess_tuning_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, tuning_validation_assessment_config=types.TuningValidationAssessmentConfig( @@ -1348,8 +1368,8 @@ def assess_batch_prediction_resources( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - resources. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction resources. 
@@ -1376,6 +1396,10 @@ def assess_batch_prediction_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig( @@ -1409,8 +1433,8 @@ def assess_batch_prediction_validity( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - validity for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction validity for. @@ -1435,6 +1459,10 @@ def assess_batch_prediction_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = self._assess_multimodal_dataset( name=dataset_name, batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig( @@ -2352,14 +2380,14 @@ async def get_multimodal_dataset( Args: name: - Required. name of a multimodal dataset. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". config: Optional. A configuration for getting the multimodal dataset. If not provided, the default configuration will be used. Returns: - A types.MultimodalDataset object representing the updated multimodal + A types.MultimodalDataset object representing the retrieved multimodal dataset. 
""" if isinstance(config, dict): @@ -2367,6 +2395,10 @@ async def get_multimodal_dataset( elif not config: config = types.VertexBaseConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + return await self._get_multimodal_dataset(config=config, name=name) async def delete_multimodal_dataset( @@ -2379,8 +2411,8 @@ async def delete_multimodal_dataset( Args: name: - Required. name of a multimodal dataset. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". config: Optional. A configuration for deleting the multimodal dataset. If not provided, the default configuration will be used. @@ -2394,6 +2426,10 @@ async def delete_multimodal_dataset( elif not config: config = types.VertexBaseConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + return await self._delete_multimodal_dataset(config=config, name=name) async def assemble( @@ -2411,8 +2447,8 @@ async def assemble( Args: name: - Required. The name of the dataset to assemble. The name should be in - the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". gemini_request_read_config: Optional. The read config to use to assemble the dataset. 
If not provided, the read config attached to the dataset will be @@ -2429,6 +2465,10 @@ async def assemble( elif not config: config = types.AssembleDatasetConfig() + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + operation = await self._assemble_multimodal_dataset( name=name, gemini_request_read_config=gemini_request_read_config, @@ -2454,8 +2494,8 @@ async def assess_tuning_resources( Args: dataset_name: - Required. The name of the dataset to assess the tuning resources - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning resources for. @@ -2477,6 +2517,10 @@ async def assess_tuning_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig( @@ -2510,8 +2554,8 @@ async def assess_tuning_validity( Args: dataset_name: - Required. The name of the dataset to assess the tuning validity - for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the tuning validity for. 
@@ -2538,6 +2582,10 @@ async def assess_tuning_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, tuning_validation_assessment_config=types.TuningValidationAssessmentConfig( @@ -2570,8 +2618,8 @@ async def assess_batch_prediction_resources( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - resources. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction resources. @@ -2598,6 +2646,10 @@ async def assess_batch_prediction_resources( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig( @@ -2631,8 +2683,8 @@ async def assess_batch_prediction_validity( Args: dataset_name: - Required. The name of the dataset to assess the batch prediction - validity for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}". + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". model_name: Required. The name of the model to assess the batch prediction validity for. 
@@ -2657,6 +2709,10 @@ async def assess_batch_prediction_validity( elif not config: config = types.AssessDatasetConfig() + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + operation = await self._assess_multimodal_dataset( name=dataset_name, batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(