mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-16 04:02:55 -06:00
feat:Make VertexAiSessionService true async.
PiperOrigin-RevId: 762547133
This commit is contained in:
parent
79681e3513
commit
d212e50c10
@ -11,9 +11,9 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -69,7 +69,8 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
if state:
|
if state:
|
||||||
session_json_dict['session_state'] = state
|
session_json_dict['session_state'] = state
|
||||||
|
|
||||||
api_response = self.api_client.request(
|
api_client = _get_api_client(self.project, self.location)
|
||||||
|
api_response = await api_client.async_request(
|
||||||
http_method='POST',
|
http_method='POST',
|
||||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
|
||||||
request_dict=session_json_dict,
|
request_dict=session_json_dict,
|
||||||
@ -81,7 +82,7 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
|
|
||||||
max_retry_attempt = 5
|
max_retry_attempt = 5
|
||||||
while max_retry_attempt >= 0:
|
while max_retry_attempt >= 0:
|
||||||
lro_response = self.api_client.request(
|
lro_response = await api_client.async_request(
|
||||||
http_method='GET',
|
http_method='GET',
|
||||||
path=f'operations/{operation_id}',
|
path=f'operations/{operation_id}',
|
||||||
request_dict={},
|
request_dict={},
|
||||||
@ -90,11 +91,11 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
if lro_response.get('done', None):
|
if lro_response.get('done', None):
|
||||||
break
|
break
|
||||||
|
|
||||||
time.sleep(1)
|
await asyncio.sleep(1)
|
||||||
max_retry_attempt -= 1
|
max_retry_attempt -= 1
|
||||||
|
|
||||||
# Get session resource
|
# Get session resource
|
||||||
get_session_api_response = self.api_client.request(
|
get_session_api_response = await api_client.async_request(
|
||||||
http_method='GET',
|
http_method='GET',
|
||||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||||
request_dict={},
|
request_dict={},
|
||||||
@ -124,7 +125,8 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||||
|
|
||||||
# Get session resource
|
# Get session resource
|
||||||
get_session_api_response = self.api_client.request(
|
api_client = _get_api_client(self.project, self.location)
|
||||||
|
get_session_api_response = await api_client.async_request(
|
||||||
http_method='GET',
|
http_method='GET',
|
||||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||||
request_dict={},
|
request_dict={},
|
||||||
@ -142,7 +144,7 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
last_update_time=update_timestamp,
|
last_update_time=update_timestamp,
|
||||||
)
|
)
|
||||||
|
|
||||||
list_events_api_response = self.api_client.request(
|
list_events_api_response = await api_client.async_request(
|
||||||
http_method='GET',
|
http_method='GET',
|
||||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
|
||||||
request_dict={},
|
request_dict={},
|
||||||
@ -181,7 +183,8 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
) -> ListSessionsResponse:
|
) -> ListSessionsResponse:
|
||||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||||
|
|
||||||
api_response = self.api_client.request(
|
api_client = _get_api_client(self.project, self.location)
|
||||||
|
api_response = await api_client.async_request(
|
||||||
http_method='GET',
|
http_method='GET',
|
||||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
|
||||||
request_dict={},
|
request_dict={},
|
||||||
@ -207,7 +210,8 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
self, *, app_name: str, user_id: str, session_id: str
|
self, *, app_name: str, user_id: str, session_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||||
self.api_client.request(
|
api_client = _get_api_client(self.project, self.location)
|
||||||
|
await api_client.async_request(
|
||||||
http_method='DELETE',
|
http_method='DELETE',
|
||||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||||
request_dict={},
|
request_dict={},
|
||||||
@ -219,15 +223,25 @@ class VertexAiSessionService(BaseSessionService):
|
|||||||
await super().append_event(session=session, event=event)
|
await super().append_event(session=session, event=event)
|
||||||
|
|
||||||
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
|
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
|
||||||
self.api_client.request(
|
api_client = _get_api_client(self.project, self.location)
|
||||||
|
await api_client.async_request(
|
||||||
http_method='POST',
|
http_method='POST',
|
||||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
|
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
|
||||||
request_dict=_convert_event_to_json(event),
|
request_dict=_convert_event_to_json(event),
|
||||||
)
|
)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
def _get_api_client(project: str, location: str):
|
||||||
|
"""Instantiates an API client for the given project and location.
|
||||||
|
|
||||||
|
It needs to be instantiated inside each request so that the event loop
|
||||||
|
management.
|
||||||
|
"""
|
||||||
|
client = genai.Client(vertexai=True, project=project, location=location)
|
||||||
|
return client._api_client
|
||||||
|
|
||||||
|
|
||||||
def _convert_event_to_json(event: Event):
|
def _convert_event_to_json(event: Event):
|
||||||
metadata_json = {
|
metadata_json = {
|
||||||
'partial': event.partial,
|
'partial': event.partial,
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
import re
|
import re
|
||||||
import this
|
import this
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
from google.adk.events import Event
|
from google.adk.events import Event
|
||||||
@ -123,7 +124,9 @@ class MockApiClient:
|
|||||||
this.session_dict: dict[str, Any] = {}
|
this.session_dict: dict[str, Any] = {}
|
||||||
this.event_dict: dict[str, list[Any]] = {}
|
this.event_dict: dict[str, list[Any]] = {}
|
||||||
|
|
||||||
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
|
async def async_request(
|
||||||
|
self, http_method: str, path: str, request_dict: dict[str, Any]
|
||||||
|
):
|
||||||
"""Mocks the API Client request method."""
|
"""Mocks the API Client request method."""
|
||||||
if http_method == 'GET':
|
if http_method == 'GET':
|
||||||
if re.match(SESSION_REGEX, path):
|
if re.match(SESSION_REGEX, path):
|
||||||
@ -194,22 +197,31 @@ class MockApiClient:
|
|||||||
|
|
||||||
def mock_vertex_ai_session_service():
|
def mock_vertex_ai_session_service():
|
||||||
"""Creates a mock Vertex AI Session service for testing."""
|
"""Creates a mock Vertex AI Session service for testing."""
|
||||||
service = VertexAiSessionService(
|
return VertexAiSessionService(
|
||||||
project='test-project', location='test-location'
|
project='test-project', location='test-location'
|
||||||
)
|
)
|
||||||
service.api_client = MockApiClient()
|
|
||||||
service.api_client.session_dict = {
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_api_client():
|
||||||
|
api_client = MockApiClient()
|
||||||
|
api_client.session_dict = {
|
||||||
'1': MOCK_SESSION_JSON_1,
|
'1': MOCK_SESSION_JSON_1,
|
||||||
'2': MOCK_SESSION_JSON_2,
|
'2': MOCK_SESSION_JSON_2,
|
||||||
'3': MOCK_SESSION_JSON_3,
|
'3': MOCK_SESSION_JSON_3,
|
||||||
}
|
}
|
||||||
service.api_client.event_dict = {
|
api_client.event_dict = {
|
||||||
'1': MOCK_EVENT_JSON,
|
'1': MOCK_EVENT_JSON,
|
||||||
}
|
}
|
||||||
return service
|
with mock.patch(
|
||||||
|
"google.adk.sessions.vertex_ai_session_service._get_api_client",
|
||||||
|
return_value=api_client,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures('mock_get_api_client')
|
||||||
async def test_get_empty_session():
|
async def test_get_empty_session():
|
||||||
session_service = mock_vertex_ai_session_service()
|
session_service = mock_vertex_ai_session_service()
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
@ -220,6 +232,7 @@ async def test_get_empty_session():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures('mock_get_api_client')
|
||||||
async def test_get_and_delete_session():
|
async def test_get_and_delete_session():
|
||||||
session_service = mock_vertex_ai_session_service()
|
session_service = mock_vertex_ai_session_service()
|
||||||
|
|
||||||
@ -241,6 +254,7 @@ async def test_get_and_delete_session():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures('mock_get_api_client')
|
||||||
async def test_list_sessions():
|
async def test_list_sessions():
|
||||||
session_service = mock_vertex_ai_session_service()
|
session_service = mock_vertex_ai_session_service()
|
||||||
sessions = await session_service.list_sessions(app_name='123', user_id='user')
|
sessions = await session_service.list_sessions(app_name='123', user_id='user')
|
||||||
@ -250,6 +264,7 @@ async def test_list_sessions():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures('mock_get_api_client')
|
||||||
async def test_create_session():
|
async def test_create_session():
|
||||||
session_service = mock_vertex_ai_session_service()
|
session_service = mock_vertex_ai_session_service()
|
||||||
|
|
||||||
@ -269,6 +284,7 @@ async def test_create_session():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures('mock_get_api_client')
|
||||||
async def test_create_session_with_custom_session_id():
|
async def test_create_session_with_custom_session_id():
|
||||||
session_service = mock_vertex_ai_session_service()
|
session_service = mock_vertex_ai_session_service()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user