Skip to content

Commit 90da804

Browse files
vertex-sdk-botcopybara-github
authored and committed
feat: GenAI Client(evals) - Add allow_cross_region_model support for create_evaluation_run
PiperOrigin-RevId: 902873150
1 parent 719f874 commit 90da804

4 files changed

Lines changed: 195 additions & 0 deletions

File tree

tests/unit/vertexai/genai/replays/test_create_evaluation_run.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,31 @@ def test_create_eval_run_with_inference_configs(client):
294294
assert evaluation_run.error is None
295295

296296

297+
def test_create_eval_run_with_allow_cross_region_model(client):
298+
"""Tests that create_evaluation_run() works with allow_cross_region_model in config."""
299+
client._api_client._http_options.api_version = "v1beta1"
300+
inference_config = types.EvaluationRunInferenceConfig(
301+
model=MODEL_NAME,
302+
prompt_template=types.EvaluationRunPromptTemplate(
303+
prompt_template="test prompt template"
304+
),
305+
)
306+
evaluation_run = client.evals.create_evaluation_run(
307+
name="test_inference_config",
308+
display_name="test_inference_config",
309+
dataset=types.EvaluationRunDataSource(evaluation_set=EVAL_SET_NAME),
310+
dest=GCS_DEST,
311+
metrics=[GENERAL_QUALITY_METRIC],
312+
inference_configs={"model_1": inference_config},
313+
labels={"label1": "value1"},
314+
config={"allow_cross_region_model": True},
315+
)
316+
assert isinstance(evaluation_run, types.EvaluationRun)
317+
assert evaluation_run.display_name == "test_inference_config"
318+
assert evaluation_run.state == types.EvaluationRunState.PENDING
319+
assert evaluation_run.error is None
320+
321+
297322
@mock.patch("uuid.uuid4")
298323
def test_create_eval_run_with_metric_resource_name(mock_uuid4, client):
299324
"""Tests create_evaluation_run with metric_resource_name."""

tests/unit/vertexai/genai/test_evals.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9136,3 +9136,104 @@ def test_computation_metric_retry_on_resource_exhausted(
91369136
summary_metric = result.summary_metrics[0]
91379137
assert summary_metric.metric_name == "bleu"
91389138
assert summary_metric.mean_score == 0.85
9139+
9140+
9141+
class TestAllowCrossRegionModel:
9142+
"""Tests for allow_cross_region_model flag for create_evaluation_run."""
9143+
9144+
def setup_method(self, method):
9145+
self.mock_api_client = mock.MagicMock()
9146+
self.mock_api_client.vertexai = True
9147+
9148+
self.mock_response = mock.MagicMock()
9149+
self.mock_response.body = json.dumps(
9150+
{
9151+
"name": "projects/123/locations/us-central1/evaluationRuns/456",
9152+
"displayName": "test_run",
9153+
"state": "PENDING",
9154+
}
9155+
)
9156+
self.mock_api_client.request.return_value = self.mock_response
9157+
9158+
def test_create_evaluation_run_config_has_allow_cross_region_model(self):
9159+
"""Verifies allow_cross_region_model field exists on CreateEvaluationRunConfig."""
9160+
config = vertexai_genai_types.CreateEvaluationRunConfig(
9161+
allow_cross_region_model=True,
9162+
)
9163+
assert config.allow_cross_region_model is True
9164+
9165+
def test_create_evaluation_run_config_from_dict(self):
9166+
"""Verifies allow_cross_region_model can be set via dict on CreateEvaluationRunConfig."""
9167+
config = vertexai_genai_types.CreateEvaluationRunConfig.model_validate(
9168+
{"allow_cross_region_model": True}
9169+
)
9170+
assert config.allow_cross_region_model is True
9171+
9172+
def test_create_evaluation_run_config_default_is_none(self):
9173+
"""Verifies the default value of allow_cross_region_model is None."""
9174+
config = vertexai_genai_types.CreateEvaluationRunConfig()
9175+
assert config.allow_cross_region_model is None
9176+
9177+
def test_create_evaluation_run_passes_allow_cross_region_model(self):
9178+
"""Verifies allow_cross_region_model is sent inside evaluationConfig in the API request."""
9179+
evals_module = evals.Evals(api_client_=self.mock_api_client)
9180+
9181+
evals_module.create_evaluation_run(
9182+
dataset=vertexai_genai_types.EvaluationRunDataSource(
9183+
evaluation_set="projects/123/locations/us-central1/evaluationSets/789"
9184+
),
9185+
metrics=[
9186+
vertexai_genai_types.EvaluationRunMetric(
9187+
metric="general_quality_v1",
9188+
metric_config=vertexai_genai_types.UnifiedMetric(
9189+
predefined_metric_spec=genai_types.PredefinedMetricSpec(
9190+
metric_spec_name="general_quality_v1",
9191+
)
9192+
),
9193+
)
9194+
],
9195+
dest="gs://test-bucket/output",
9196+
config={"allow_cross_region_model": True},
9197+
)
9198+
9199+
self.mock_api_client.request.assert_called_once()
9200+
call_args = self.mock_api_client.request.call_args
9201+
request_body = call_args[0][2] # Third positional arg is the request dict
9202+
assert (
9203+
request_body.get("evaluationConfig", {}).get("allowCrossRegionModel")
9204+
is True
9205+
)
9206+
9207+
@pytest.mark.asyncio
9208+
async def test_create_evaluation_run_async_passes_allow_cross_region_model(self):
9209+
"""Verifies allow_cross_region_model is sent inside evaluationConfig in the async API request."""
9210+
self.mock_api_client.async_request = mock.AsyncMock(
9211+
return_value=self.mock_response
9212+
)
9213+
async_evals_module = evals.AsyncEvals(api_client_=self.mock_api_client)
9214+
9215+
await async_evals_module.create_evaluation_run(
9216+
dataset=vertexai_genai_types.EvaluationRunDataSource(
9217+
evaluation_set="projects/123/locations/us-central1/evaluationSets/789"
9218+
),
9219+
metrics=[
9220+
vertexai_genai_types.EvaluationRunMetric(
9221+
metric="general_quality_v1",
9222+
metric_config=vertexai_genai_types.UnifiedMetric(
9223+
predefined_metric_spec=genai_types.PredefinedMetricSpec(
9224+
metric_spec_name="general_quality_v1",
9225+
)
9226+
),
9227+
)
9228+
],
9229+
dest="gs://test-bucket/output",
9230+
config={"allow_cross_region_model": True},
9231+
)
9232+
9233+
self.mock_api_client.async_request.assert_called_once()
9234+
call_args = self.mock_api_client.async_request.call_args
9235+
request_body = call_args[0][2] # Third positional arg is the request dict
9236+
assert (
9237+
request_body.get("evaluationConfig", {}).get("allowCrossRegionModel")
9238+
is True
9239+
)

vertexai/_genai/evals.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,13 @@ def _EvaluationRunConfig_from_vertex(
391391
[item for item in getv(from_object, ["lossAnalysisConfig"])],
392392
)
393393

394+
if getv(from_object, ["allowCrossRegionModel"]) is not None:
395+
setv(
396+
to_object,
397+
["allow_cross_region_model"],
398+
getv(from_object, ["allowCrossRegionModel"]),
399+
)
400+
394401
return to_object
395402

396403

@@ -425,6 +432,13 @@ def _EvaluationRunConfig_to_vertex(
425432
[item for item in getv(from_object, ["loss_analysis_config"])],
426433
)
427434

435+
if getv(from_object, ["allow_cross_region_model"]) is not None:
436+
setv(
437+
to_object,
438+
["allowCrossRegionModel"],
439+
getv(from_object, ["allow_cross_region_model"]),
440+
)
441+
428442
return to_object
429443

430444

@@ -2653,6 +2667,13 @@ def create_evaluation_run(
26532667
``max_top_cluster_count``. Mutually exclusive with
26542668
``loss_analysis_metrics``.
26552669
config: The configuration for the evaluation run.
2670+
- allow_cross_region_model: Allows the evaluation run to use cross
2671+
region models. When this flag is set, the service may route traffic to
2672+
other regions if a model is unavailable in the current region (e.g.,
2673+
to a `global` endpoint). If a fully-qualified model endpoint resource
2674+
name with a different region than the run location is provided
2675+
elsewhere in the run config, this flag must be set to true or the
2676+
request will fail.
26562677
26572678
Returns:
26582679
The created evaluation run.
@@ -2672,6 +2693,11 @@ def create_evaluation_run(
26722693
else (agent_info or evals_types.AgentInfo())
26732694
)
26742695

2696+
if not config:
2697+
config = types.CreateEvaluationRunConfig()
2698+
if isinstance(config, dict):
2699+
config = types.CreateEvaluationRunConfig.model_validate(config)
2700+
26752701
if agent_info and not inference_configs:
26762702
parsed_user_simulator_config = (
26772703
evals_types.UserSimulatorConfig.model_validate(user_simulator_config)
@@ -2712,6 +2738,7 @@ def create_evaluation_run(
27122738
output_config=output_config,
27132739
metrics=resolved_metrics,
27142740
loss_analysis_config=resolved_loss_configs,
2741+
allow_cross_region_model=getattr(config, "allow_cross_region_model", None),
27152742
)
27162743
resolved_inference_configs = _evals_common._resolve_inference_configs(
27172744
self._api_client, resolved_dataset, inference_configs, parsed_agent_info
@@ -4422,6 +4449,8 @@ async def create_evaluation_run(
44224449
``max_top_cluster_count``. Mutually exclusive with
44234450
``loss_analysis_metrics``.
44244451
config: The configuration for the evaluation run.
4452+
- allow_cross_region_model: Opt-in flag to authorize cross-region
4453+
routing for model inference. Applies to both scraping and evaluation.
44254454
44264455
Returns:
44274456
The created evaluation run.
@@ -4441,6 +4470,11 @@ async def create_evaluation_run(
44414470
else (agent_info or evals_types.AgentInfo())
44424471
)
44434472

4473+
if not config:
4474+
config = types.CreateEvaluationRunConfig()
4475+
if isinstance(config, dict):
4476+
config = types.CreateEvaluationRunConfig.model_validate(config)
4477+
44444478
if agent_info and not inference_configs:
44454479
parsed_user_simulator_config = (
44464480
evals_types.UserSimulatorConfig.model_validate(user_simulator_config)
@@ -4481,6 +4515,7 @@ async def create_evaluation_run(
44814515
output_config=output_config,
44824516
metrics=resolved_metrics,
44834517
loss_analysis_config=resolved_loss_configs,
4518+
allow_cross_region_model=getattr(config, "allow_cross_region_model", None),
44844519
)
44854520
resolved_inference_configs = _evals_common._resolve_inference_configs(
44864521
self._api_client, resolved_dataset, inference_configs, parsed_agent_info

vertexai/_genai/types/common.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,6 +2401,15 @@ class EvaluationRunConfig(_common.BaseModel):
24012401
default=None,
24022402
description="""Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns.""",
24032403
)
2404+
allow_cross_region_model: Optional[bool] = Field(
2405+
default=None,
2406+
description="""Allows the evaluation run to use cross region models. When this
2407+
flag is set, the service may route traffic to other regions if a model is
2408+
unavailable in the current region (e.g., to a `global`endpoint). If a
2409+
fully-qualified model endpoint resource name with a different region than
2410+
the run location is provided elsewhere in the run config, this flag must
2411+
be set to true or the request will fail.""",
2412+
)
24042413

24052414

24062415
class EvaluationRunConfigDict(TypedDict, total=False):
@@ -2421,6 +2430,14 @@ class EvaluationRunConfigDict(TypedDict, total=False):
24212430
loss_analysis_config: Optional[list[LossAnalysisConfigDict]]
24222431
"""Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns."""
24232432

2433+
allow_cross_region_model: Optional[bool]
2434+
"""Allows the evaluation run to use cross region models. When this
2435+
flag is set, the service may route traffic to other regions if a model is
2436+
unavailable in the current region (e.g., to a `global` endpoint). If a
2437+
fully-qualified model endpoint resource name with a different region than
2438+
the run location is provided elsewhere in the run config, this flag must
2439+
be set to true or the request will fail."""
2440+
24242441

24252442
EvaluationRunConfigOrDict = Union[EvaluationRunConfig, EvaluationRunConfigDict]
24262443

@@ -2551,6 +2568,15 @@ class CreateEvaluationRunConfig(_common.BaseModel):
25512568
http_options: Optional[genai_types.HttpOptions] = Field(
25522569
default=None, description="""Used to override HTTP request options."""
25532570
)
2571+
allow_cross_region_model: Optional[bool] = Field(
2572+
default=None,
2573+
description="""Allows the evaluation run to use cross region models. When this
2574+
flag is set, the service may route traffic to other regions if a model is
2575+
unavailable in the current region (e.g., to a `global` endpoint). If a
2576+
fully-qualified model endpoint resource name with a different region than
2577+
the run location is provided elsewhere in the run config, this flag must
2578+
be set to true or the request will fail.""",
2579+
)
25542580

25552581

25562582
class CreateEvaluationRunConfigDict(TypedDict, total=False):
@@ -2559,6 +2585,14 @@ class CreateEvaluationRunConfigDict(TypedDict, total=False):
25592585
http_options: Optional[genai_types.HttpOptionsDict]
25602586
"""Used to override HTTP request options."""
25612587

2588+
allow_cross_region_model: Optional[bool]
2589+
"""Allows the evaluation run to use cross region models. When this
2590+
flag is set, the service may route traffic to other regions if a model is
2591+
unavailable in the current region (e.g., to a `global` endpoint). If a
2592+
fully-qualified model endpoint resource name with a different region than
2593+
the run location is provided elsewhere in the run config, this flag must
2594+
be set to true or the request will fail."""
2595+
25622596

25632597
CreateEvaluationRunConfigOrDict = Union[
25642598
CreateEvaluationRunConfig, CreateEvaluationRunConfigDict

0 commit comments

Comments
 (0)