mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
feat! Update session service interface to be async.
Also keep the sync version in the InMemorySessionService as create_session_sync() as a temporary migration option. PiperOrigin-RevId: 759252188
This commit is contained in:
committed by
Copybara-Service
parent
5b3204c356
commit
1804ca39a6
@@ -55,7 +55,7 @@ async def run_input_file(
|
||||
input_file = InputFile.model_validate_json(f.read())
|
||||
input_file.state['_time'] = datetime.now()
|
||||
|
||||
session = await session_service.create_session(
|
||||
session = session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=input_file.state
|
||||
)
|
||||
for query in input_file.queries:
|
||||
@@ -130,7 +130,7 @@ async def run_cli(
|
||||
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
||||
agent_module = importlib.import_module(agent_folder_name)
|
||||
user_id = 'test_user'
|
||||
session = await session_service.create_session(
|
||||
session = session_service.create_session(
|
||||
app_name=agent_folder_name, user_id=user_id
|
||||
)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
@@ -145,12 +145,14 @@ async def run_cli(
|
||||
input_path=input_file,
|
||||
)
|
||||
elif saved_session_file:
|
||||
|
||||
loaded_session = None
|
||||
with open(saved_session_file, 'r') as f:
|
||||
loaded_session = Session.model_validate_json(f.read())
|
||||
|
||||
if loaded_session:
|
||||
for event in loaded_session.events:
|
||||
await session_service.append_event(session, event)
|
||||
session_service.append_event(session, event)
|
||||
content = event.content
|
||||
if not content or not content.parts or not content.parts[0].text:
|
||||
continue
|
||||
@@ -179,7 +181,7 @@ async def run_cli(
|
||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||
|
||||
# Fetch the session again to get all the details.
|
||||
session = await session_service.get_session(
|
||||
session = session_service.get_session(
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
|
||||
@@ -333,12 +333,10 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def get_session(
|
||||
app_name: str, user_id: str, session_id: str
|
||||
) -> Session:
|
||||
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
@@ -349,15 +347,14 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
||||
def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
list_sessions_response = await session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
return [
|
||||
session
|
||||
for session in list_sessions_response.sessions
|
||||
for session in session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
).sessions
|
||||
# Remove sessions that were generated as a part of Eval.
|
||||
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
|
||||
]
|
||||
@@ -366,7 +363,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def create_session_with_id(
|
||||
def create_session_with_id(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
@@ -375,7 +372,7 @@ def get_fast_api_app(
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
if (
|
||||
await session_service.get_session(
|
||||
session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
is not None
|
||||
@@ -385,7 +382,7 @@ def get_fast_api_app(
|
||||
status_code=400, detail=f"Session already exists: {session_id}"
|
||||
)
|
||||
logger.info("New session created: %s", session_id)
|
||||
return await session_service.create_session(
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state, session_id=session_id
|
||||
)
|
||||
|
||||
@@ -393,7 +390,7 @@ def get_fast_api_app(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def create_session(
|
||||
def create_session(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
@@ -401,7 +398,7 @@ def get_fast_api_app(
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
logger.info("New session created")
|
||||
return await session_service.create_session(
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
|
||||
@@ -445,7 +442,7 @@ def get_fast_api_app(
|
||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||
):
|
||||
# Get the session
|
||||
session = await session_service.get_session(
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
assert session, "Session not found."
|
||||
@@ -533,7 +530,7 @@ def get_fast_api_app(
|
||||
session_id=eval_case_result.session_id,
|
||||
)
|
||||
)
|
||||
eval_case_result.session_details = await session_service.get_session(
|
||||
eval_case_result.session_details = session_service.get_session(
|
||||
app_name=app_name,
|
||||
user_id=eval_case_result.user_id,
|
||||
session_id=eval_case_result.session_id,
|
||||
@@ -618,10 +615,10 @@ def get_fast_api_app(
|
||||
return eval_result_files
|
||||
|
||||
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
|
||||
async def delete_session(app_name: str, user_id: str, session_id: str):
|
||||
def delete_session(app_name: str, user_id: str, session_id: str):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
await session_service.delete_session(
|
||||
session_service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
@@ -716,7 +713,7 @@ def get_fast_api_app(
|
||||
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
session = await session_service.get_session(
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
@@ -738,7 +735,7 @@ def get_fast_api_app(
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
# SSE endpoint
|
||||
session = await session_service.get_session(
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
@@ -779,7 +776,7 @@ def get_fast_api_app(
|
||||
):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
)
|
||||
session_events = session.events if session else []
|
||||
@@ -836,7 +833,7 @@ def get_fast_api_app(
|
||||
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
|
||||
@@ -126,7 +126,7 @@ class EvaluationGenerator:
|
||||
user_id = initial_session.user_id if initial_session else "test_user_id"
|
||||
session_id = session_id if session_id else str(uuid.uuid4())
|
||||
|
||||
_ = await session_service.create_session(
|
||||
_ = session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state=initial_session.state if initial_session else {},
|
||||
|
||||
@@ -173,7 +173,7 @@ class Runner:
|
||||
The events generated by the agent.
|
||||
"""
|
||||
with tracer.start_as_current_span('invocation'):
|
||||
session = await self.session_service.get_session(
|
||||
session = self.session_service.get_session(
|
||||
app_name=self.app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
@@ -197,7 +197,7 @@ class Runner:
|
||||
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
||||
async for event in invocation_context.agent.run_async(invocation_context):
|
||||
if not event.partial:
|
||||
await self.session_service.append_event(session=session, event=event)
|
||||
self.session_service.append_event(session=session, event=event)
|
||||
yield event
|
||||
|
||||
async def _append_new_message_to_session(
|
||||
@@ -242,7 +242,7 @@ class Runner:
|
||||
author='user',
|
||||
content=new_message,
|
||||
)
|
||||
await self.session_service.append_event(session=session, event=event)
|
||||
self.session_service.append_event(session=session, event=event)
|
||||
|
||||
async def run_live(
|
||||
self,
|
||||
@@ -324,7 +324,7 @@ class Runner:
|
||||
)
|
||||
|
||||
async for event in invocation_context.agent.run_live(invocation_context):
|
||||
await self.session_service.append_event(session=session, event=event)
|
||||
self.session_service.append_event(session=session, event=event)
|
||||
yield event
|
||||
|
||||
async def close_session(self, session: Session):
|
||||
@@ -335,7 +335,7 @@ class Runner:
|
||||
"""
|
||||
if self.memory_service:
|
||||
await self.memory_service.add_session_to_memory(session)
|
||||
await self.session_service.close_session(session=session)
|
||||
self.session_service.close_session(session=session)
|
||||
|
||||
def _find_agent_to_run(
|
||||
self, session: Session, root_agent: BaseAgent
|
||||
|
||||
@@ -47,7 +47,7 @@ class BaseSessionService(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def create_session(
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -67,9 +67,10 @@ class BaseSessionService(abc.ABC):
|
||||
Returns:
|
||||
session: The newly created session instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_session(
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -78,24 +79,28 @@ class BaseSessionService(abc.ABC):
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Optional[Session]:
|
||||
"""Gets a session."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_sessions(
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
"""Lists all the sessions."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_session(
|
||||
def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
"""Deletes a session."""
|
||||
pass
|
||||
|
||||
async def close_session(self, *, session: Session):
|
||||
def close_session(self, *, session: Session):
|
||||
"""Closes a session."""
|
||||
# TODO: determine whether we want to finalize the session here.
|
||||
pass
|
||||
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
"""Appends an event to a session object."""
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
@@ -283,7 +283,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
Base.metadata.create_all(self.db_engine)
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -357,7 +357,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
return session
|
||||
|
||||
@override
|
||||
async def get_session(
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -431,7 +431,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
return session
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
@@ -454,7 +454,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
@override
|
||||
async def delete_session(
|
||||
def delete_session(
|
||||
self, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
@@ -467,7 +467,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
sessionFactory.commit()
|
||||
|
||||
@override
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
logger.info(f"Append event: {event} to session {session.id}")
|
||||
|
||||
if event.partial:
|
||||
@@ -552,10 +552,9 @@ class DatabaseSessionService(BaseSessionService):
|
||||
session.last_update_time = storage_session.update_time.timestamp()
|
||||
|
||||
# Also update the in-memory session
|
||||
await super().append_event(session=session, event=event)
|
||||
super().append_event(session=session, event=event)
|
||||
return event
|
||||
|
||||
|
||||
def convert_event(event: StorageEvent) -> Event:
|
||||
"""Converts a storage event to an event."""
|
||||
return Event(
|
||||
|
||||
@@ -44,7 +44,7 @@ class InMemorySessionService(BaseSessionService):
|
||||
self.app_state: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -106,7 +106,7 @@ class InMemorySessionService(BaseSessionService):
|
||||
return self._merge_state(app_name, user_id, copied_session)
|
||||
|
||||
@override
|
||||
async def get_session(
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -193,7 +193,7 @@ class InMemorySessionService(BaseSessionService):
|
||||
return copied_session
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
|
||||
@@ -221,7 +221,7 @@ class InMemorySessionService(BaseSessionService):
|
||||
sessions_without_events.append(copied_session)
|
||||
return ListSessionsResponse(sessions=sessions_without_events)
|
||||
|
||||
async def delete_session(
|
||||
def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
self._delete_session_impl(
|
||||
@@ -250,9 +250,16 @@ class InMemorySessionService(BaseSessionService):
|
||||
self.sessions[app_name][user_id].pop(session_id)
|
||||
|
||||
@override
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
return self._append_event_impl(session=session, event=event)
|
||||
|
||||
def append_event_sync(self, session: Session, event: Event) -> Event:
|
||||
logger.warning('Deprecated. Please migrate to the async method.')
|
||||
return self._append_event_impl(session=session, event=event)
|
||||
|
||||
def _append_event_impl(self, session: Session, event: Event) -> Event:
|
||||
# Update the in-memory session.
|
||||
await super().append_event(session=session, event=event)
|
||||
super().append_event(session=session, event=event)
|
||||
session.last_update_time = event.timestamp
|
||||
|
||||
# Update the storage session
|
||||
@@ -279,7 +286,7 @@ class InMemorySessionService(BaseSessionService):
|
||||
] = event.actions.state_delta[key]
|
||||
|
||||
storage_session = self.sessions[app_name][user_id].get(session_id)
|
||||
await super().append_event(session=storage_session, event=event)
|
||||
super().append_event(session=storage_session, event=event)
|
||||
|
||||
storage_session.last_update_time = event.timestamp
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
self.api_client = client._api_client
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -68,7 +68,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
if state:
|
||||
session_json_dict['session_state'] = state
|
||||
|
||||
api_response = await self.api_client.async_request(
|
||||
api_response = self.api_client.request(
|
||||
http_method='POST',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
|
||||
request_dict=session_json_dict,
|
||||
@@ -80,7 +80,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
|
||||
max_retry_attempt = 5
|
||||
while max_retry_attempt >= 0:
|
||||
lro_response = await self.api_client.async_request(
|
||||
lro_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'operations/{operation_id}',
|
||||
request_dict={},
|
||||
@@ -93,7 +93,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
max_retry_attempt -= 1
|
||||
|
||||
# Get session resource
|
||||
get_session_api_response = await self.api_client.async_request(
|
||||
get_session_api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
@@ -112,7 +112,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
return session
|
||||
|
||||
@override
|
||||
async def get_session(
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
@@ -123,7 +123,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
|
||||
# Get session resource
|
||||
get_session_api_response = await self.api_client.async_request(
|
||||
get_session_api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
@@ -141,7 +141,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
last_update_time=update_timestamp,
|
||||
)
|
||||
|
||||
list_events_api_response = await self.api_client.async_request(
|
||||
list_events_api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
|
||||
request_dict={},
|
||||
@@ -175,7 +175,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
return session
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
@@ -202,23 +202,23 @@ class VertexAiSessionService(BaseSessionService):
|
||||
sessions.append(session)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
async def delete_session(
|
||||
def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
await self.api_client.async_request(
|
||||
self.api_client.request(
|
||||
http_method='DELETE',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
@override
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
# Update the in-memory session.
|
||||
await super().append_event(session=session, event=event)
|
||||
super().append_event(session=session, event=event)
|
||||
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
|
||||
await self.api_client.async_request(
|
||||
self.api_client.request(
|
||||
http_method='POST',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
|
||||
request_dict=_convert_event_to_json(event),
|
||||
|
||||
@@ -129,7 +129,7 @@ class AgentTool(BaseTool):
|
||||
session_service=InMemorySessionService(),
|
||||
memory_service=InMemoryMemoryService(),
|
||||
)
|
||||
session = await runner.session_service.create_session(
|
||||
session = runner.session_service.create_session(
|
||||
app_name=self.agent.name,
|
||||
user_id='tmp_user',
|
||||
state=tool_context.state.to_dict(),
|
||||
|
||||
Reference in New Issue
Block a user