fix eval in api server and cli for below issue:

https://github.com/google/adk-python/issues/651

PiperOrigin-RevId: 756906937
This commit is contained in:
Xiang (Sean) Zhou 2025-05-09 14:25:53 -07:00 committed by Copybara-Service
parent c4ecef54bc
commit 166df01236
2 changed files with 29 additions and 21 deletions

View File

@ -18,6 +18,8 @@ from datetime import datetime
import logging import logging
import os import os
import tempfile import tempfile
from typing import AsyncGenerator
from typing import Coroutine
from typing import Optional from typing import Optional
import click import click
@ -267,9 +269,16 @@ def cli_eval(
eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path) eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
async def _collect_async_gen(
async_gen_coroutine: Coroutine[
AsyncGenerator[EvalResult, None], None, None
],
) -> list[EvalResult]:
return [result async for result in async_gen_coroutine]
try: try:
eval_results = list( eval_results = asyncio.run(
asyncio.run( _collect_async_gen(
run_evals( run_evals(
eval_set_to_evals, eval_set_to_evals,
root_agent, root_agent,

View File

@ -24,7 +24,11 @@ import re
import sys import sys
import traceback import traceback
import typing import typing
from typing import Any, List, Literal, Optional, Union from typing import Any
from typing import List
from typing import Literal
from typing import Optional
from typing import Union
import click import click
from fastapi import FastAPI from fastapi import FastAPI
@ -52,7 +56,8 @@ from ..agents import RunConfig
from ..agents.base_agent import BaseAgent from ..agents.base_agent import BaseAgent
from ..agents.live_request_queue import LiveRequest from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent, LlmAgent from ..agents.llm_agent import Agent
from ..agents.llm_agent import LlmAgent
from ..agents.run_config import StreamingMode from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService from ..artifacts import InMemoryArtifactService
from ..events.event import Event from ..events.event import Event
@ -467,8 +472,16 @@ def get_fast_api_app(
"Eval ids to run list is empty. We will all evals in the eval set." "Eval ids to run list is empty. We will all evals in the eval set."
) )
root_agent = await _get_root_agent_async(app_name) root_agent = await _get_root_agent_async(app_name)
eval_results = list( return [
await run_evals( RunEvalResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
session_id=eval_result.session_id,
)
async for eval_result in run_evals(
eval_set_to_evals, eval_set_to_evals,
root_agent, root_agent,
getattr(root_agent, "reset_data", None), getattr(root_agent, "reset_data", None),
@ -476,21 +489,7 @@ def get_fast_api_app(
session_service=session_service, session_service=session_service,
artifact_service=artifact_service, artifact_service=artifact_service,
) )
) ]
run_eval_results = []
for eval_result in eval_results:
run_eval_results.append(
RunEvalResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_id=eval_result.eval_id,
final_eval_status=eval_result.final_eval_status,
eval_metric_results=eval_result.eval_metric_results,
session_id=eval_result.session_id,
)
)
return run_eval_results
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
def delete_session(app_name: str, user_id: str, session_id: str): def delete_session(app_name: str, user_id: str, session_id: str):