ADK changes

PiperOrigin-RevId: 759259620
This commit is contained in:
Google Team Member
2025-05-15 12:46:12 -07:00
committed by Copybara-Service
parent 1804ca39a6
commit 05917cabbd
23 changed files with 264 additions and 268 deletions
+4 -6
View File
@@ -55,7 +55,7 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now()
session = session_service.create_session(
session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
@@ -130,7 +130,7 @@ async def run_cli(
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
session = await session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent
@@ -145,14 +145,12 @@ async def run_cli(
input_path=input_file,
)
elif saved_session_file:
loaded_session = None
with open(saved_session_file, 'r') as f:
loaded_session = Session.model_validate_json(f.read())
if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
await session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
@@ -181,7 +179,7 @@ async def run_cli(
session_path = f'{agent_module_path}/{session_id}.session.json'
# Fetch the session again to get all the details.
session = session_service.get_session(
session = await session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
+22 -19
View File
@@ -333,10 +333,12 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
async def get_session(
app_name: str, user_id: str, session_id: str
) -> Session:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
@@ -347,14 +349,15 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def list_sessions(app_name: str, user_id: str) -> list[Session]:
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
list_sessions_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
return [
session
for session in session_service.list_sessions(
app_name=app_name, user_id=user_id
).sessions
for session in list_sessions_response.sessions
# Remove sessions that were generated as a part of Eval.
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
]
@@ -363,7 +366,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
def create_session_with_id(
async def create_session_with_id(
app_name: str,
user_id: str,
session_id: str,
@@ -372,7 +375,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
if (
session_service.get_session(
await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
is not None
@@ -382,7 +385,7 @@ def get_fast_api_app(
status_code=400, detail=f"Session already exists: {session_id}"
)
logger.info("New session created: %s", session_id)
return session_service.create_session(
return await session_service.create_session(
app_name=app_name, user_id=user_id, state=state, session_id=session_id
)
@@ -390,7 +393,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
def create_session(
async def create_session(
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
@@ -398,7 +401,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
logger.info("New session created")
return session_service.create_session(
return await session_service.create_session(
app_name=app_name, user_id=user_id, state=state
)
@@ -442,7 +445,7 @@ def get_fast_api_app(
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
):
# Get the session
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_name, user_id=req.user_id, session_id=req.session_id
)
assert session, "Session not found."
@@ -530,7 +533,7 @@ def get_fast_api_app(
session_id=eval_case_result.session_id,
)
)
eval_case_result.session_details = session_service.get_session(
eval_case_result.session_details = await session_service.get_session(
app_name=app_name,
user_id=eval_case_result.user_id,
session_id=eval_case_result.session_id,
@@ -615,10 +618,10 @@ def get_fast_api_app(
return eval_result_files
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
def delete_session(app_name: str, user_id: str, session_id: str):
async def delete_session(app_name: str, user_id: str, session_id: str):
# Connect to managed session if agent_engine_id is set.
app_name = agent_engine_id if agent_engine_id else app_name
session_service.delete_session(
await session_service.delete_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
@@ -713,7 +716,7 @@ def get_fast_api_app(
async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
@@ -735,7 +738,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
)
if not session:
@@ -776,7 +779,7 @@ def get_fast_api_app(
):
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
session_events = session.events if session else []
@@ -833,7 +836,7 @@ def get_fast_api_app(
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
session = session_service.get_session(
session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
)
if not session: