refactor: rename agent_dir to agents_dir and rename app_id to app_name in fast_api.py to make it consistent among every endpoints

PiperOrigin-RevId: 763483339
This commit is contained in:
Xiang (Sean) Zhou 2025-05-26 11:50:44 -07:00 committed by Copybara-Service
parent 6b89ceb49a
commit be0786ea88
6 changed files with 32 additions and 32 deletions

View File

@ -355,7 +355,7 @@ def cli_eval(
# Write eval set results. # Write eval set results.
local_eval_set_results_manager = LocalEvalSetResultsManager( local_eval_set_results_manager = LocalEvalSetResultsManager(
agent_dir=os.path.dirname(agent_module_file_path) agents_dir=os.path.dirname(agent_module_file_path)
) )
eval_set_id_to_eval_results = collections.defaultdict(list) eval_set_id_to_eval_results = collections.defaultdict(list)
for eval_case_result in eval_results: for eval_case_result in eval_results:
@ -500,7 +500,7 @@ def cli_web(
) )
app = get_fast_api_app( app = get_fast_api_app(
agent_dir=agents_dir, agents_dir=agents_dir,
session_db_url=session_db_url, session_db_url=session_db_url,
allow_origins=allow_origins, allow_origins=allow_origins,
web=True, web=True,
@ -601,7 +601,7 @@ def cli_api_server(
config = uvicorn.Config( config = uvicorn.Config(
get_fast_api_app( get_fast_api_app(
agent_dir=agents_dir, agents_dir=agents_dir,
session_db_url=session_db_url, session_db_url=session_db_url,
allow_origins=allow_origins, allow_origins=allow_origins,
web=False, web=False,

View File

@ -191,7 +191,7 @@ class GetEventGraphResult(common.BaseModel):
def get_fast_api_app( def get_fast_api_app(
*, *,
agent_dir: str, agents_dir: str,
session_db_url: str = "", session_db_url: str = "",
allow_origins: Optional[list[str]] = None, allow_origins: Optional[list[str]] = None,
web: bool, web: bool,
@ -210,7 +210,7 @@ def get_fast_api_app(
memory_exporter = InMemoryExporter(session_trace_dict) memory_exporter = InMemoryExporter(session_trace_dict)
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter)) provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
if trace_to_cloud: if trace_to_cloud:
envs.load_dotenv_for_agent("", agent_dir) envs.load_dotenv_for_agent("", agents_dir)
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
processor = export.BatchSpanProcessor( processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id) CloudTraceSpanExporter(project_id=project_id)
@ -249,8 +249,8 @@ def get_fast_api_app(
allow_headers=["*"], allow_headers=["*"],
) )
if agent_dir not in sys.path: if agents_dir not in sys.path:
sys.path.append(agent_dir) sys.path.append(agents_dir)
runner_dict = {} runner_dict = {}
root_agent_dict = {} root_agent_dict = {}
@ -259,8 +259,8 @@ def get_fast_api_app(
artifact_service = InMemoryArtifactService() artifact_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService() memory_service = InMemoryMemoryService()
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir) eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
eval_set_results_manager = LocalEvalSetResultsManager(agent_dir=agent_dir) eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
# Build the Session service # Build the Session service
agent_engine_id = "" agent_engine_id = ""
@ -270,7 +270,7 @@ def get_fast_api_app(
agent_engine_id = session_db_url.split("://")[1] agent_engine_id = session_db_url.split("://")[1]
if not agent_engine_id: if not agent_engine_id:
raise click.ClickException("Agent engine id can not be empty.") raise click.ClickException("Agent engine id can not be empty.")
envs.load_dotenv_for_agent("", agent_dir) envs.load_dotenv_for_agent("", agents_dir)
session_service = VertexAiSessionService( session_service = VertexAiSessionService(
os.environ["GOOGLE_CLOUD_PROJECT"], os.environ["GOOGLE_CLOUD_PROJECT"],
os.environ["GOOGLE_CLOUD_LOCATION"], os.environ["GOOGLE_CLOUD_LOCATION"],
@ -282,7 +282,7 @@ def get_fast_api_app(
@app.get("/list-apps") @app.get("/list-apps")
def list_apps() -> list[str]: def list_apps() -> list[str]:
base_path = Path.cwd() / agent_dir base_path = Path.cwd() / agents_dir
if not base_path.exists(): if not base_path.exists():
raise HTTPException(status_code=404, detail="Path not found") raise HTTPException(status_code=404, detail="Path not found")
if not base_path.is_dir(): if not base_path.is_dir():
@ -398,9 +398,9 @@ def get_fast_api_app(
app_name=app_name, user_id=user_id, state=state app_name=app_name, user_id=user_id, state=state
) )
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str: def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str:
return os.path.join( return os.path.join(
agent_dir, agents_dir,
app_name, app_name,
eval_set_id + _EVAL_SET_FILE_EXTENSION, eval_set_id + _EVAL_SET_FILE_EXTENSION,
) )
@ -490,7 +490,7 @@ def get_fast_api_app(
# Create a mapping from eval set file to all the evals that needed to be # Create a mapping from eval set file to all the evals that needed to be
# run. # run.
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir) envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
@ -663,9 +663,9 @@ def get_fast_api_app(
@app.post("/run", response_model_exclude_none=True) @app.post("/run", response_model_exclude_none=True)
async def agent_run(req: AgentRunRequest) -> list[Event]: async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name app_name = agent_engine_id if agent_engine_id else req.app_name
session = await session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id app_name=app_name, user_id=req.user_id, session_id=req.session_id
) )
if not session: if not session:
raise HTTPException(status_code=404, detail="Session not found") raise HTTPException(status_code=404, detail="Session not found")
@ -684,10 +684,10 @@ def get_fast_api_app(
@app.post("/run_sse") @app.post("/run_sse")
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name app_name = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint # SSE endpoint
session = await session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id app_name=app_name, user_id=req.user_id, session_id=req.session_id
) )
if not session: if not session:
raise HTTPException(status_code=404, detail="Session not found") raise HTTPException(status_code=404, detail="Session not found")
@ -726,9 +726,9 @@ def get_fast_api_app(
app_name: str, user_id: str, session_id: str, event_id: str app_name: str, user_id: str, session_id: str, event_id: str
): ):
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
session_events = session.events if session else [] session_events = session.events if session else []
event = next((x for x in session_events if x.id == event_id), None) event = next((x for x in session_events if x.id == event_id), None)
@ -783,9 +783,9 @@ def get_fast_api_app(
await websocket.accept() await websocket.accept()
# Connect to managed session if agent_engine_id is set. # Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session( session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
if not session: if not session:
# Accept first so that the client is aware of connection establishment, # Accept first so that the client is aware of connection establishment,
@ -855,7 +855,7 @@ def get_fast_api_app(
async def _get_runner_async(app_name: str) -> Runner: async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app.""" """Returns the runner for the given app."""
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir) envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
if app_name in runner_dict: if app_name in runner_dict:
return runner_dict[app_name] return runner_dict[app_name]
root_agent = await _get_root_agent_async(app_name) root_agent = await _get_root_agent_async(app_name)

View File

@ -36,8 +36,8 @@ def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str:
class LocalEvalSetResultsManager(EvalSetResultsManager): class LocalEvalSetResultsManager(EvalSetResultsManager):
"""An EvalSetResult manager that stores eval set results locally on disk.""" """An EvalSetResult manager that stores eval set results locally on disk."""
def __init__(self, agent_dir: str): def __init__(self, agents_dir: str):
self._agent_dir = agent_dir self._agents_dir = agents_dir
@override @override
def save_eval_set_result( def save_eval_set_result(
@ -108,4 +108,4 @@ class LocalEvalSetResultsManager(EvalSetResultsManager):
return eval_result_files return eval_result_files
def _get_eval_history_dir(self, app_name: str) -> str: def _get_eval_history_dir(self, app_name: str) -> str:
return os.path.join(self._agent_dir, app_name, _ADK_EVAL_HISTORY_DIR) return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR)

View File

@ -182,8 +182,8 @@ def load_eval_set_from_file(
class LocalEvalSetsManager(EvalSetsManager): class LocalEvalSetsManager(EvalSetsManager):
"""An EvalSets manager that stores eval sets locally on disk.""" """An EvalSets manager that stores eval sets locally on disk."""
def __init__(self, agent_dir: str): def __init__(self, agents_dir: str):
self._agent_dir = agent_dir self._agents_dir = agents_dir
@override @override
def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet: def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
@ -216,7 +216,7 @@ class LocalEvalSetsManager(EvalSetsManager):
@override @override
def list_eval_sets(self, app_name: str) -> list[str]: def list_eval_sets(self, app_name: str) -> list[str]:
"""Returns a list of EvalSets that belong to the given app_name.""" """Returns a list of EvalSets that belong to the given app_name."""
eval_set_file_path = os.path.join(self._agent_dir, app_name) eval_set_file_path = os.path.join(self._agents_dir, app_name)
eval_sets = [] eval_sets = []
for file in os.listdir(eval_set_file_path): for file in os.listdir(eval_set_file_path):
if file.endswith(_EVAL_SET_FILE_EXTENSION): if file.endswith(_EVAL_SET_FILE_EXTENSION):
@ -247,7 +247,7 @@ class LocalEvalSetsManager(EvalSetsManager):
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
return os.path.join( return os.path.join(
self._agent_dir, self._agents_dir,
app_name, app_name,
eval_set_id + _EVAL_SET_FILE_EXTENSION, eval_set_id + _EVAL_SET_FILE_EXTENSION,
) )

View File

@ -308,7 +308,7 @@ def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
): ):
# Get the FastAPI app, but don't actually run it # Get the FastAPI app, but don't actually run it
app = get_fast_api_app( app = get_fast_api_app(
agent_dir=".", web=True, session_db_url="", allow_origins=["*"] agents_dir=".", web=True, session_db_url="", allow_origins=["*"]
) )
# Create a TestClient that doesn't start a real server # Create a TestClient that doesn't start a real server

View File

@ -202,7 +202,7 @@ class InMemoryRunner:
session_id=self.session.id, session_id=self.session.id,
new_message=get_user_content(new_message), new_message=get_user_content(new_message),
): ):
events.append(event) events.append(event)
return events return events
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: