Skip to content

Commit f5909b2

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add RetrieveSkills semantic search method in Vertex AI Skill Registry SDK
PiperOrigin-RevId: 911650461
1 parent d947295 commit f5909b2

5 files changed

Lines changed: 419 additions & 4 deletions

File tree

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Tests the skills.retrieve() method against the autopush endpoint."""
2+
3+
from tests.unit.vertexai.genai.replays import pytest_helper
4+
from vertexai._genai import types
5+
6+
pytestmark = pytest_helper.setup(
7+
file=__file__,
8+
globals_for_file=globals(),
9+
)
10+
11+
12+
def test_retrieve_skills(client):
13+
# Target the prod endpoint for the Skill Registry API
14+
client._api_client._http_options.base_url = (
15+
"https://us-central1-aiplatform.googleapis.com"
16+
)
17+
18+
response = client.skills.retrieve(query="stubby", config={"top_k": 2})
19+
20+
assert isinstance(response, types.RetrieveSkillsResponse)
21+
assert response.retrieved_skills is not None
22+
23+
for retrieved in response.retrieved_skills:
24+
assert isinstance(retrieved, types.RetrievedSkill)
25+
assert retrieved.skill_name is not None
26+
assert retrieved.description is not None
Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# //third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/test_genai_skills.py
22
import json
33
from unittest import mock
4+
import google.auth.credentials
45
from vertexai import _genai as genai
56
from vertexai._genai import client as vertexai_client
67
from google.genai import types as genai_types
@@ -9,31 +10,42 @@
910

1011
@pytest.fixture
1112
def skills_client():
12-
creds = mock.MagicMock()
13+
creds = mock.create_autospec(google.auth.credentials.Credentials, instance=True)
1314
creds.token = "test_token"
1415
client = vertexai_client.Client(
1516
project="test-project", location="test-location", credentials=creds
1617
)
1718
return client.skills
1819

1920

21+
@pytest.fixture
22+
def async_skills_client():
23+
creds = mock.create_autospec(google.auth.credentials.Credentials, instance=True)
24+
creds.token = "test_token"
25+
client = vertexai_client.Client(
26+
project="test-project", location="test-location", credentials=creds
27+
)
28+
return client.aio.skills
29+
30+
2031
class TestGenaiSkills:
2132
mock_get_skill_response = {
2233
"name": "projects/test-project/locations/test-location/skills/test-skill",
2334
"displayName": "My Test Skill",
2435
}
2536

2637
def test_get_skill(self, skills_client):
27-
"""Tests the get_skill method."""
28-
with mock.patch.object(skills_client._api_client, "request") as request_mock:
38+
with mock.patch.object(
39+
skills_client._api_client, "request", autospec=True
40+
) as request_mock:
2941
request_mock.return_value = genai_types.HttpResponse(
3042
body=json.dumps(self.mock_get_skill_response)
3143
)
3244
skill_name = (
3345
"projects/test-project/locations/test-location/skills/test-skill"
3446
)
3547
skill = skills_client.get(name=skill_name)
36-
request_mock.assert_called_with(
48+
request_mock.assert_called_once_with(
3749
"get",
3850
skill_name,
3951
{"_url": {"name": skill_name}},
@@ -42,3 +54,93 @@ def test_get_skill(self, skills_client):
4254
assert isinstance(skill, genai.types.Skill)
4355
assert skill.name == skill_name
4456
assert skill.display_name == "My Test Skill"
57+
58+
def test_retrieve_skills_response(self, skills_client):
59+
mock_retrieve_response = {
60+
"retrievedSkills": [
61+
{
62+
"skillName": (
63+
"projects/test-project/locations/test-location/skills/skill-1"
64+
),
65+
"description": "Skill 1 Description",
66+
},
67+
{
68+
"skillName": (
69+
"projects/test-project/locations/test-location/skills/skill-2"
70+
),
71+
"description": "Skill 2 Description",
72+
},
73+
]
74+
}
75+
76+
with mock.patch.object(
77+
skills_client._api_client, "request", autospec=True
78+
) as request_mock:
79+
request_mock.return_value = genai_types.HttpResponse(
80+
body=json.dumps(mock_retrieve_response)
81+
)
82+
83+
response = skills_client.retrieve(query="test query", config={"top_k": 5})
84+
85+
assert isinstance(response, genai.types.RetrieveSkillsResponse)
86+
assert len(response.retrieved_skills) == 2
87+
assert response.retrieved_skills[0].skill_name == (
88+
"projects/test-project/locations/test-location/skills/skill-1"
89+
)
90+
assert response.retrieved_skills[0].description == "Skill 1 Description"
91+
92+
def test_retrieve_skills_request_params(self, skills_client):
93+
mock_retrieve_response = {"retrievedSkills": []}
94+
95+
with mock.patch.object(
96+
skills_client._api_client, "request", autospec=True
97+
) as request_mock:
98+
request_mock.return_value = genai_types.HttpResponse(
99+
body=json.dumps(mock_retrieve_response)
100+
)
101+
102+
skills_client.retrieve(query="test query", config={"top_k": 5})
103+
104+
request_mock.assert_called_once_with(
105+
"get",
106+
"skills:retrieve?query=test+query&topK=5",
107+
{"_query": {"query": "test query", "topK": 5}},
108+
None,
109+
)
110+
111+
@pytest.mark.asyncio
112+
async def test_retrieve_skills_async(self, async_skills_client):
113+
mock_retrieve_response = {
114+
"retrievedSkills": [
115+
{
116+
"skillName": (
117+
"projects/test-project/locations/test-location/skills/skill-1"
118+
),
119+
"description": "Skill 1 Description",
120+
}
121+
]
122+
}
123+
124+
with mock.patch.object(
125+
async_skills_client._api_client, "async_request", autospec=True
126+
) as request_mock:
127+
request_mock.return_value = genai_types.HttpResponse(
128+
body=json.dumps(mock_retrieve_response)
129+
)
130+
131+
response = await async_skills_client.retrieve(
132+
query="test query", config={"top_k": 1}
133+
)
134+
135+
assert isinstance(response, genai.types.RetrieveSkillsResponse)
136+
assert len(response.retrieved_skills) == 1
137+
assert response.retrieved_skills[0].skill_name == (
138+
"projects/test-project/locations/test-location/skills/skill-1"
139+
)
140+
141+
request_mock.assert_called_once_with(
142+
"get",
143+
"skills:retrieve?query=test+query&topK=1",
144+
{"_query": {"query": "test query", "topK": 1}},
145+
None,
146+
)

vertexai/_genai/skills.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,36 @@ def _GetSkillRequestParameters_to_vertex(
4444
return to_object
4545

4646

47+
def _RetrieveSkillsConfig_to_vertex(
48+
from_object: Union[dict[str, Any], object],
49+
parent_object: Optional[dict[str, Any]] = None,
50+
) -> dict[str, Any]:
51+
to_object: dict[str, Any] = {}
52+
53+
if getv(from_object, ["top_k"]) is not None:
54+
setv(parent_object, ["_query", "topK"], getv(from_object, ["top_k"]))
55+
56+
return to_object
57+
58+
59+
def _RetrieveSkillsRequestParameters_to_vertex(
60+
from_object: Union[dict[str, Any], object],
61+
parent_object: Optional[dict[str, Any]] = None,
62+
) -> dict[str, Any]:
63+
to_object: dict[str, Any] = {}
64+
if getv(from_object, ["query"]) is not None:
65+
setv(to_object, ["_query", "query"], getv(from_object, ["query"]))
66+
67+
if getv(from_object, ["config"]) is not None:
68+
setv(
69+
to_object,
70+
["config"],
71+
_RetrieveSkillsConfig_to_vertex(getv(from_object, ["config"]), to_object),
72+
)
73+
74+
return to_object
75+
76+
4777
class Skills(_api_module.BaseModule):
4878
"""Class for managing Skills in the Skill Registry."""
4979

@@ -116,6 +146,75 @@ def get(
116146
self._api_client._verify_response(return_value)
117147
return return_value
118148

149+
def retrieve(
150+
self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None
151+
) -> types.RetrieveSkillsResponse:
152+
"""
153+
Retrieves skills semantically matched to a query.
154+
"""
155+
156+
parameter_model = types._RetrieveSkillsRequestParameters(
157+
query=query,
158+
config=config,
159+
)
160+
161+
request_url_dict: Optional[dict[str, str]]
162+
if not self._api_client.vertexai:
163+
raise ValueError(
164+
"This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
165+
)
166+
else:
167+
request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model)
168+
request_url_dict = request_dict.get("_url")
169+
if request_url_dict:
170+
path = "skills:retrieve".format_map(request_url_dict)
171+
else:
172+
path = "skills:retrieve"
173+
174+
query_params = request_dict.get("_query")
175+
if query_params:
176+
path = f"{path}?{urlencode(query_params)}"
177+
# TODO: remove the hack that pops config.
178+
request_dict.pop("config", None)
179+
180+
http_options: Optional[types.HttpOptions] = None
181+
if (
182+
parameter_model.config is not None
183+
and parameter_model.config.http_options is not None
184+
):
185+
http_options = parameter_model.config.http_options
186+
187+
request_dict = _common.convert_to_dict(request_dict)
188+
request_dict = _common.encode_unserializable_types(request_dict)
189+
190+
response = self._api_client.request("get", path, request_dict, http_options)
191+
192+
response_dict = {} if not response.body else json.loads(response.body)
193+
194+
return_value = types.RetrieveSkillsResponse._from_response(
195+
response=response_dict,
196+
kwargs=(
197+
{
198+
"config": {
199+
"response_schema": getattr(
200+
parameter_model.config, "response_schema", None
201+
),
202+
"response_json_schema": getattr(
203+
parameter_model.config, "response_json_schema", None
204+
),
205+
"include_all_fields": getattr(
206+
parameter_model.config, "include_all_fields", None
207+
),
208+
}
209+
}
210+
if getattr(parameter_model, "config", None)
211+
else {}
212+
),
213+
)
214+
215+
self._api_client._verify_response(return_value)
216+
return return_value
217+
119218

120219
class AsyncSkills(_api_module.BaseModule):
121220
"""Class for managing Skills in the Skill Registry."""
@@ -190,3 +289,74 @@ async def get(
190289

191290
self._api_client._verify_response(return_value)
192291
return return_value
292+
293+
async def retrieve(
294+
self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None
295+
) -> types.RetrieveSkillsResponse:
296+
"""
297+
Retrieves skills semantically matched to a query.
298+
"""
299+
300+
parameter_model = types._RetrieveSkillsRequestParameters(
301+
query=query,
302+
config=config,
303+
)
304+
305+
request_url_dict: Optional[dict[str, str]]
306+
if not self._api_client.vertexai:
307+
raise ValueError(
308+
"This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
309+
)
310+
else:
311+
request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model)
312+
request_url_dict = request_dict.get("_url")
313+
if request_url_dict:
314+
path = "skills:retrieve".format_map(request_url_dict)
315+
else:
316+
path = "skills:retrieve"
317+
318+
query_params = request_dict.get("_query")
319+
if query_params:
320+
path = f"{path}?{urlencode(query_params)}"
321+
# TODO: remove the hack that pops config.
322+
request_dict.pop("config", None)
323+
324+
http_options: Optional[types.HttpOptions] = None
325+
if (
326+
parameter_model.config is not None
327+
and parameter_model.config.http_options is not None
328+
):
329+
http_options = parameter_model.config.http_options
330+
331+
request_dict = _common.convert_to_dict(request_dict)
332+
request_dict = _common.encode_unserializable_types(request_dict)
333+
334+
response = await self._api_client.async_request(
335+
"get", path, request_dict, http_options
336+
)
337+
338+
response_dict = {} if not response.body else json.loads(response.body)
339+
340+
return_value = types.RetrieveSkillsResponse._from_response(
341+
response=response_dict,
342+
kwargs=(
343+
{
344+
"config": {
345+
"response_schema": getattr(
346+
parameter_model.config, "response_schema", None
347+
),
348+
"response_json_schema": getattr(
349+
parameter_model.config, "response_json_schema", None
350+
),
351+
"include_all_fields": getattr(
352+
parameter_model.config, "include_all_fields", None
353+
),
354+
}
355+
}
356+
if getattr(parameter_model, "config", None)
357+
else {}
358+
),
359+
)
360+
361+
self._api_client._verify_response(return_value)
362+
return return_value

0 commit comments

Comments
 (0)