feat:Make VertexAiSessionService true async.

PiperOrigin-RevId: 762547133
This commit is contained in:
Shangjie Chen 2025-05-23 13:30:57 -07:00 committed by Copybara-Service
parent 79681e3513
commit d212e50c10
2 changed files with 47 additions and 17 deletions

View File

@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import logging import logging
import re import re
import time
from typing import Any from typing import Any
from typing import Optional from typing import Optional
@ -69,7 +69,8 @@ class VertexAiSessionService(BaseSessionService):
if state: if state:
session_json_dict['session_state'] = state session_json_dict['session_state'] = state
api_response = self.api_client.request( api_client = _get_api_client(self.project, self.location)
api_response = await api_client.async_request(
http_method='POST', http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions', path=f'reasoningEngines/{reasoning_engine_id}/sessions',
request_dict=session_json_dict, request_dict=session_json_dict,
@ -81,7 +82,7 @@ class VertexAiSessionService(BaseSessionService):
max_retry_attempt = 5 max_retry_attempt = 5
while max_retry_attempt >= 0: while max_retry_attempt >= 0:
lro_response = self.api_client.request( lro_response = await api_client.async_request(
http_method='GET', http_method='GET',
path=f'operations/{operation_id}', path=f'operations/{operation_id}',
request_dict={}, request_dict={},
@ -90,11 +91,11 @@ class VertexAiSessionService(BaseSessionService):
if lro_response.get('done', None): if lro_response.get('done', None):
break break
time.sleep(1) await asyncio.sleep(1)
max_retry_attempt -= 1 max_retry_attempt -= 1
# Get session resource # Get session resource
get_session_api_response = self.api_client.request( get_session_api_response = await api_client.async_request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
@ -124,7 +125,8 @@ class VertexAiSessionService(BaseSessionService):
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
# Get session resource # Get session resource
get_session_api_response = self.api_client.request( api_client = _get_api_client(self.project, self.location)
get_session_api_response = await api_client.async_request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
@ -142,7 +144,7 @@ class VertexAiSessionService(BaseSessionService):
last_update_time=update_timestamp, last_update_time=update_timestamp,
) )
list_events_api_response = self.api_client.request( list_events_api_response = await api_client.async_request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
request_dict={}, request_dict={},
@ -181,7 +183,8 @@ class VertexAiSessionService(BaseSessionService):
) -> ListSessionsResponse: ) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
api_response = self.api_client.request( api_client = _get_api_client(self.project, self.location)
api_response = await api_client.async_request(
http_method='GET', http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
request_dict={}, request_dict={},
@ -207,7 +210,8 @@ class VertexAiSessionService(BaseSessionService):
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> None: ) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name) reasoning_engine_id = _parse_reasoning_engine_id(app_name)
self.api_client.request( api_client = _get_api_client(self.project, self.location)
await api_client.async_request(
http_method='DELETE', http_method='DELETE',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={}, request_dict={},
@ -219,15 +223,25 @@ class VertexAiSessionService(BaseSessionService):
await super().append_event(session=session, event=event) await super().append_event(session=session, event=event)
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name) reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
self.api_client.request( api_client = _get_api_client(self.project, self.location)
await api_client.async_request(
http_method='POST', http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent', path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
request_dict=_convert_event_to_json(event), request_dict=_convert_event_to_json(event),
) )
return event return event
def _get_api_client(project: str, location: str):
"""Instantiates an API client for the given project and location.
It needs to be instantiated inside each request so that the event loop
management.
"""
client = genai.Client(vertexai=True, project=project, location=location)
return client._api_client
def _convert_event_to_json(event: Event): def _convert_event_to_json(event: Event):
metadata_json = { metadata_json = {
'partial': event.partial, 'partial': event.partial,

View File

@ -15,6 +15,7 @@
import re import re
import this import this
from typing import Any from typing import Any
from unittest import mock
from dateutil.parser import isoparse from dateutil.parser import isoparse
from google.adk.events import Event from google.adk.events import Event
@ -123,7 +124,9 @@ class MockApiClient:
this.session_dict: dict[str, Any] = {} this.session_dict: dict[str, Any] = {}
this.event_dict: dict[str, list[Any]] = {} this.event_dict: dict[str, list[Any]] = {}
def request(self, http_method: str, path: str, request_dict: dict[str, Any]): async def async_request(
self, http_method: str, path: str, request_dict: dict[str, Any]
):
"""Mocks the API Client request method.""" """Mocks the API Client request method."""
if http_method == 'GET': if http_method == 'GET':
if re.match(SESSION_REGEX, path): if re.match(SESSION_REGEX, path):
@ -194,22 +197,31 @@ class MockApiClient:
def mock_vertex_ai_session_service(): def mock_vertex_ai_session_service():
"""Creates a mock Vertex AI Session service for testing.""" """Creates a mock Vertex AI Session service for testing."""
service = VertexAiSessionService( return VertexAiSessionService(
project='test-project', location='test-location' project='test-project', location='test-location'
) )
service.api_client = MockApiClient()
service.api_client.session_dict = {
@pytest.fixture
def mock_get_api_client():
api_client = MockApiClient()
api_client.session_dict = {
'1': MOCK_SESSION_JSON_1, '1': MOCK_SESSION_JSON_1,
'2': MOCK_SESSION_JSON_2, '2': MOCK_SESSION_JSON_2,
'3': MOCK_SESSION_JSON_3, '3': MOCK_SESSION_JSON_3,
} }
service.api_client.event_dict = { api_client.event_dict = {
'1': MOCK_EVENT_JSON, '1': MOCK_EVENT_JSON,
} }
return service with mock.patch(
"google.adk.sessions.vertex_ai_session_service._get_api_client",
return_value=api_client,
):
yield
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_empty_session(): async def test_get_empty_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
@ -220,6 +232,7 @@ async def test_get_empty_session():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_and_delete_session(): async def test_get_and_delete_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
@ -241,6 +254,7 @@ async def test_get_and_delete_session():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions(): async def test_list_sessions():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
sessions = await session_service.list_sessions(app_name='123', user_id='user') sessions = await session_service.list_sessions(app_name='123', user_id='user')
@ -250,6 +264,7 @@ async def test_list_sessions():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session(): async def test_create_session():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()
@ -269,6 +284,7 @@ async def test_create_session():
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session_with_custom_session_id(): async def test_create_session_with_custom_session_id():
session_service = mock_vertex_ai_session_service() session_service = mock_vertex_ai_session_service()