Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ def test_get_dataset_from_public_method(client):
assert dataset.display_name == "test-display-name"


def test_get_dataset_by_id(client):
    """Fetch a multimodal dataset using only its bare ID instead of a full resource name."""
    dataset_id = "8810841321427173376"
    fetched = client.datasets.get_multimodal_dataset(name=dataset_id)

    # The returned object is expected to carry the full resource name (DATASET),
    # even though only the ID was supplied.
    assert isinstance(fetched, types.MultimodalDataset)
    assert fetched.name == DATASET
    assert fetched.display_name == "test-display-name"


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand All @@ -67,3 +76,13 @@ async def test_get_dataset_from_public_method_async(client):
assert isinstance(dataset, types.MultimodalDataset)
assert dataset.name == DATASET
assert dataset.display_name == "test-display-name"


@pytest.mark.asyncio
async def test_get_dataset_by_id_async(client):
    """Async variant: fetch a multimodal dataset using only its bare ID."""
    dataset_id = "8810841321427173376"
    fetched = await client.aio.datasets.get_multimodal_dataset(name=dataset_id)

    # The returned object is expected to carry the full resource name (DATASET),
    # even though only the ID was supplied.
    assert isinstance(fetched, types.MultimodalDataset)
    assert fetched.name == DATASET
    assert fetched.display_name == "test-display-name"
8 changes: 8 additions & 0 deletions vertexai/_genai/_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,11 @@ async def save_dataframe_to_bigquery_async(
)
await asyncio.to_thread(copy_job.result)
await asyncio.to_thread(bq_client.delete_table, temp_table_id)


def resolve_dataset_name(resource_name_or_id: str, project: str, location: str) -> str:
    """Resolves a dataset name or ID to a full resource name.

    Args:
        resource_name_or_id:
            Either a fully-qualified dataset resource name
            ("projects/{project}/locations/{location}/datasets/{dataset}")
            or a bare dataset ID ("123").
        project:
            Project used to build the resource name when only an ID is given.
        location:
            Location used to build the resource name when only an ID is given.

    Returns:
        The input unchanged if it is already a fully-qualified resource name;
        otherwise the ID prefixed with
        "projects/{project}/locations/{location}/datasets/".
    """
    # Any value starting with "projects/" is already a resource name, even if
    # it points at a different project/location than the client's defaults.
    # Comparing against the client-specific prefix would wrongly prepend that
    # prefix to such names and yield a malformed path.
    if resource_name_or_id.startswith("projects/"):
        return resource_name_or_id
    return f"projects/{project}/locations/{location}/datasets/{resource_name_or_id}"
114 changes: 85 additions & 29 deletions vertexai/_genai/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,8 +1130,8 @@ def get_multimodal_dataset(

Args:
name:
Required. name of a multimodal dataset. The name should be in
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
config:
Optional. A configuration for getting the multimodal dataset. If not
provided, the default configuration will be used.
Expand All @@ -1145,6 +1145,10 @@ def get_multimodal_dataset(
elif not config:
config = types.VertexBaseConfig()

name = _datasets_utils.resolve_dataset_name(
name, self._api_client.project, self._api_client.location
)

return self._get_multimodal_dataset(config=config, name=name)

def delete_multimodal_dataset(
Expand All @@ -1157,8 +1161,8 @@ def delete_multimodal_dataset(

Args:
name:
Required. name of a multimodal dataset. The name should be in
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
config:
Optional. A configuration for deleting the multimodal dataset. If not
provided, the default configuration will be used.
Expand All @@ -1172,6 +1176,10 @@ def delete_multimodal_dataset(
elif not config:
config = types.VertexBaseConfig()

name = _datasets_utils.resolve_dataset_name(
name, self._api_client.project, self._api_client.location
)

return self._delete_multimodal_dataset(config=config, name=name)

def assemble(
Expand All @@ -1189,8 +1197,8 @@ def assemble(

Args:
name:
Required. The name of the dataset to assemble. The name should be in
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
gemini_request_read_config:
Optional. The read config to use to assemble the dataset. If
not provided, the read config attached to the dataset will be
Expand All @@ -1207,6 +1215,10 @@ def assemble(
elif not config:
config = types.AssembleDatasetConfig()

name = _datasets_utils.resolve_dataset_name(
name, self._api_client.project, self._api_client.location
)

operation = self._assemble_multimodal_dataset(
name=name,
gemini_request_read_config=gemini_request_read_config,
Expand All @@ -1232,8 +1244,8 @@ def assess_tuning_resources(

Args:
dataset_name:
Required. The name of the dataset to assess the tuning resources
for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the tuning resources
for.
Expand All @@ -1255,6 +1267,10 @@ def assess_tuning_resources(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = self._assess_multimodal_dataset(
name=dataset_name,
tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig(
Expand Down Expand Up @@ -1288,8 +1304,8 @@ def assess_tuning_validity(

Args:
dataset_name:
Required. The name of the dataset to assess the tuning validity
for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the tuning validity
for.
Expand All @@ -1316,6 +1332,10 @@ def assess_tuning_validity(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = self._assess_multimodal_dataset(
name=dataset_name,
tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
Expand Down Expand Up @@ -1348,8 +1368,8 @@ def assess_batch_prediction_resources(

Args:
dataset_name:
Required. The name of the dataset to assess the batch prediction
resources. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the batch prediction
resources.
Expand All @@ -1376,6 +1396,10 @@ def assess_batch_prediction_resources(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = self._assess_multimodal_dataset(
name=dataset_name,
batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig(
Expand Down Expand Up @@ -1409,8 +1433,8 @@ def assess_batch_prediction_validity(

Args:
dataset_name:
Required. The name of the dataset to assess the batch prediction
validity for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the batch prediction
validity for.
Expand All @@ -1435,6 +1459,10 @@ def assess_batch_prediction_validity(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = self._assess_multimodal_dataset(
name=dataset_name,
batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(
Expand Down Expand Up @@ -2352,21 +2380,25 @@ async def get_multimodal_dataset(

Args:
name:
Required. name of a multimodal dataset. The name should be in
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
config:
Optional. A configuration for getting the multimodal dataset. If not
provided, the default configuration will be used.

Returns:
A types.MultimodalDataset object representing the updated multimodal
A types.MultimodalDataset object representing the retrieved multimodal
dataset.
"""
if isinstance(config, dict):
config = types.VertexBaseConfig(**config)
elif not config:
config = types.VertexBaseConfig()

name = _datasets_utils.resolve_dataset_name(
name, self._api_client.project, self._api_client.location
)

return await self._get_multimodal_dataset(config=config, name=name)

async def delete_multimodal_dataset(
Expand All @@ -2379,8 +2411,8 @@ async def delete_multimodal_dataset(

Args:
name:
Required. name of a multimodal dataset. The name should be in
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
config:
Optional. A configuration for deleting the multimodal dataset. If not
provided, the default configuration will be used.
Expand All @@ -2394,6 +2426,10 @@ async def delete_multimodal_dataset(
elif not config:
config = types.VertexBaseConfig()

name = _datasets_utils.resolve_dataset_name(
name, self._api_client.project, self._api_client.location
)

return await self._delete_multimodal_dataset(config=config, name=name)

async def assemble(
Expand All @@ -2411,8 +2447,8 @@ async def assemble(

Args:
name:
Required. The name of the dataset to assemble. The name should be in
the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
gemini_request_read_config:
Optional. The read config to use to assemble the dataset. If
not provided, the read config attached to the dataset will be
Expand All @@ -2429,6 +2465,10 @@ async def assemble(
elif not config:
config = types.AssembleDatasetConfig()

name = _datasets_utils.resolve_dataset_name(
name, self._api_client.project, self._api_client.location
)

operation = await self._assemble_multimodal_dataset(
name=name,
gemini_request_read_config=gemini_request_read_config,
Expand All @@ -2454,8 +2494,8 @@ async def assess_tuning_resources(

Args:
dataset_name:
Required. The name of the dataset to assess the tuning resources
for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the tuning resources
for.
Expand All @@ -2477,6 +2517,10 @@ async def assess_tuning_resources(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = await self._assess_multimodal_dataset(
name=dataset_name,
tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig(
Expand Down Expand Up @@ -2510,8 +2554,8 @@ async def assess_tuning_validity(

Args:
dataset_name:
Required. The name of the dataset to assess the tuning validity
for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the tuning validity
for.
Expand All @@ -2538,6 +2582,10 @@ async def assess_tuning_validity(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = await self._assess_multimodal_dataset(
name=dataset_name,
tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
Expand Down Expand Up @@ -2570,8 +2618,8 @@ async def assess_batch_prediction_resources(

Args:
dataset_name:
Required. The name of the dataset to assess the batch prediction
resources. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the batch prediction
resources.
Expand All @@ -2598,6 +2646,10 @@ async def assess_batch_prediction_resources(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = await self._assess_multimodal_dataset(
name=dataset_name,
batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig(
Expand Down Expand Up @@ -2631,8 +2683,8 @@ async def assess_batch_prediction_validity(

Args:
dataset_name:
Required. The name of the dataset to assess the batch prediction
validity for. The name should be in the format of "projects/{project}/locations/{location}/datasets/{dataset}".
Required. A fully-qualified resource name or ID of the dataset.
Example: "projects/.../locations/.../datasets/123" or "123".
model_name:
Required. The name of the model to assess the batch prediction
validity for.
Expand All @@ -2657,6 +2709,10 @@ async def assess_batch_prediction_validity(
elif not config:
config = types.AssessDatasetConfig()

dataset_name = _datasets_utils.resolve_dataset_name(
dataset_name, self._api_client.project, self._api_client.location
)

operation = await self._assess_multimodal_dataset(
name=dataset_name,
batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig(
Expand Down
Loading