From be0786ea880b48984ae9a601dfe6d40adadcf854 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 26 May 2025 11:50:44 -0700 Subject: [PATCH] refactor: rename agent_dir to agents_dir and rename app_id to app_name in fast_api.py to make it consistent among every endpoints PiperOrigin-RevId: 763483339 --- src/google/adk/cli/cli_tools_click.py | 6 +-- src/google/adk/cli/fast_api.py | 40 +++++++++---------- .../local_eval_set_results_manager.py | 6 +-- .../adk/evaluation/local_eval_sets_manager.py | 8 ++-- tests/unittests/fast_api/test_fast_api.py | 2 +- tests/unittests/testing_utils.py | 2 +- 6 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index e861356..e238bae 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -355,7 +355,7 @@ def cli_eval( # Write eval set results. local_eval_set_results_manager = LocalEvalSetResultsManager( - agent_dir=os.path.dirname(agent_module_file_path) + agents_dir=os.path.dirname(agent_module_file_path) ) eval_set_id_to_eval_results = collections.defaultdict(list) for eval_case_result in eval_results: @@ -500,7 +500,7 @@ def cli_web( ) app = get_fast_api_app( - agent_dir=agents_dir, + agents_dir=agents_dir, session_db_url=session_db_url, allow_origins=allow_origins, web=True, @@ -601,7 +601,7 @@ def cli_api_server( config = uvicorn.Config( get_fast_api_app( - agent_dir=agents_dir, + agents_dir=agents_dir, session_db_url=session_db_url, allow_origins=allow_origins, web=False, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index b9ead36..a4a4b61 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -191,7 +191,7 @@ class GetEventGraphResult(common.BaseModel): def get_fast_api_app( *, - agent_dir: str, + agents_dir: str, session_db_url: str = "", allow_origins: Optional[list[str]] = None, web: bool, @@ -210,7 +210,7 @@ def get_fast_api_app( memory_exporter = InMemoryExporter(session_trace_dict) provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter)) if trace_to_cloud: - envs.load_dotenv_for_agent("", agent_dir) + envs.load_dotenv_for_agent("", agents_dir) if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): processor = export.BatchSpanProcessor( CloudTraceSpanExporter(project_id=project_id) @@ -249,8 +249,8 @@ def get_fast_api_app( allow_headers=["*"], ) - if agent_dir not in sys.path: - sys.path.append(agent_dir) + if agents_dir not in sys.path: + sys.path.append(agents_dir) runner_dict = {} root_agent_dict = {} @@ -259,8 +259,8 @@ def get_fast_api_app( artifact_service = InMemoryArtifactService() memory_service = InMemoryMemoryService() - eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir) - eval_set_results_manager = LocalEvalSetResultsManager(agent_dir=agent_dir) + eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) + eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) # Build the Session service agent_engine_id = "" @@ -270,7 +270,7 @@ def get_fast_api_app( agent_engine_id = session_db_url.split("://")[1] if not agent_engine_id: raise click.ClickException("Agent engine id can not be empty.") - envs.load_dotenv_for_agent("", agent_dir) + envs.load_dotenv_for_agent("", agents_dir) session_service = VertexAiSessionService( os.environ["GOOGLE_CLOUD_PROJECT"], os.environ["GOOGLE_CLOUD_LOCATION"], @@ -282,7 +282,7 @@ def get_fast_api_app( @app.get("/list-apps") def list_apps() -> list[str]: - base_path = Path.cwd() / agent_dir + base_path = Path.cwd() / agents_dir if not base_path.exists(): raise HTTPException(status_code=404, detail="Path not found") if not base_path.is_dir(): @@ -398,9 +398,9 @@ def get_fast_api_app( app_name=app_name, user_id=user_id, state=state ) - def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str: + def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str: return os.path.join( - agent_dir, + agents_dir, app_name, eval_set_id + _EVAL_SET_FILE_EXTENSION, ) @@ -490,7 +490,7 @@ def get_fast_api_app( # Create a mapping from eval set file to all the evals that needed to be # run. - envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir) + envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir) eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) @@ -663,9 +663,9 @@ def get_fast_api_app( @app.post("/run", response_model_exclude_none=True) async def agent_run(req: AgentRunRequest) -> list[Event]: # Connect to managed session if agent_engine_id is set. - app_id = agent_engine_id if agent_engine_id else req.app_name + app_name = agent_engine_id if agent_engine_id else req.app_name session = await session_service.get_session( - app_name=app_id, user_id=req.user_id, session_id=req.session_id + app_name=app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -684,10 +684,10 @@ def get_fast_api_app( @app.post("/run_sse") async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: # Connect to managed session if agent_engine_id is set. - app_id = agent_engine_id if agent_engine_id else req.app_name + app_name = agent_engine_id if agent_engine_id else req.app_name # SSE endpoint session = await session_service.get_session( - app_name=app_id, user_id=req.user_id, session_id=req.session_id + app_name=app_name, user_id=req.user_id, session_id=req.session_id ) if not session: raise HTTPException(status_code=404, detail="Session not found") @@ -726,9 +726,9 @@ def get_fast_api_app( app_name: str, user_id: str, session_id: str, event_id: str ): # Connect to managed session if agent_engine_id is set. - app_id = agent_engine_id if agent_engine_id else app_name + app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( - app_name=app_id, user_id=user_id, session_id=session_id + app_name=app_name, user_id=user_id, session_id=session_id ) session_events = session.events if session else [] event = next((x for x in session_events if x.id == event_id), None) @@ -783,9 +783,9 @@ def get_fast_api_app( await websocket.accept() # Connect to managed session if agent_engine_id is set. - app_id = agent_engine_id if agent_engine_id else app_name + app_name = agent_engine_id if agent_engine_id else app_name session = await session_service.get_session( - app_name=app_id, user_id=user_id, session_id=session_id + app_name=app_name, user_id=user_id, session_id=session_id ) if not session: # Accept first so that the client is aware of connection establishment, @@ -855,7 +855,7 @@ def get_fast_api_app( async def _get_runner_async(app_name: str) -> Runner: """Returns the runner for the given app.""" - envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir) + envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir) if app_name in runner_dict: return runner_dict[app_name] root_agent = await _get_root_agent_async(app_name) diff --git a/src/google/adk/evaluation/local_eval_set_results_manager.py b/src/google/adk/evaluation/local_eval_set_results_manager.py index d1496cc..f18e984 100644 --- a/src/google/adk/evaluation/local_eval_set_results_manager.py +++ b/src/google/adk/evaluation/local_eval_set_results_manager.py @@ -36,8 +36,8 @@ def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str: class LocalEvalSetResultsManager(EvalSetResultsManager): """An EvalSetResult manager that stores eval set results locally on disk.""" - def __init__(self, agent_dir: str): - self._agent_dir = agent_dir + def __init__(self, agents_dir: str): + self._agents_dir = agents_dir @override def save_eval_set_result( @@ -108,4 +108,4 @@ class LocalEvalSetResultsManager(EvalSetResultsManager): return eval_result_files def _get_eval_history_dir(self, app_name: str) -> str: - return os.path.join(self._agent_dir, app_name, _ADK_EVAL_HISTORY_DIR) + return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR) diff --git a/src/google/adk/evaluation/local_eval_sets_manager.py b/src/google/adk/evaluation/local_eval_sets_manager.py index ad61cf2..7907499 100644 --- a/src/google/adk/evaluation/local_eval_sets_manager.py +++ b/src/google/adk/evaluation/local_eval_sets_manager.py @@ -182,8 +182,8 @@ def load_eval_set_from_file( class LocalEvalSetsManager(EvalSetsManager): """An EvalSets manager that stores eval sets locally on disk.""" - def __init__(self, agent_dir: str): - self._agent_dir = agent_dir + def __init__(self, agents_dir: str): + self._agents_dir = agents_dir @override def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet: @@ -216,7 +216,7 @@ class LocalEvalSetsManager(EvalSetsManager): @override def list_eval_sets(self, app_name: str) -> list[str]: """Returns a list of EvalSets that belong to the given app_name.""" - eval_set_file_path = os.path.join(self._agent_dir, app_name) + eval_set_file_path = os.path.join(self._agents_dir, app_name) eval_sets = [] for file in os.listdir(eval_set_file_path): if file.endswith(_EVAL_SET_FILE_EXTENSION): @@ -247,7 +247,7 @@ class LocalEvalSetsManager(EvalSetsManager): def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: return os.path.join( - self._agent_dir, + self._agents_dir, app_name, eval_set_id + _EVAL_SET_FILE_EXTENSION, ) diff --git a/tests/unittests/fast_api/test_fast_api.py b/tests/unittests/fast_api/test_fast_api.py index c206821..9729098 100644 --- a/tests/unittests/fast_api/test_fast_api.py +++ b/tests/unittests/fast_api/test_fast_api.py @@ -308,7 +308,7 @@ def test_app(mock_session_service, mock_artifact_service, mock_memory_service): ): # Get the FastAPI app, but don't actually run it app = get_fast_api_app( - agent_dir=".", web=True, session_db_url="", allow_origins=["*"] + agents_dir=".", web=True, session_db_url="", allow_origins=["*"] ) # Create a TestClient that doesn't start a real server diff --git a/tests/unittests/testing_utils.py b/tests/unittests/testing_utils.py index 7efc01a..1a8ed52 100644 --- a/tests/unittests/testing_utils.py +++ b/tests/unittests/testing_utils.py @@ -202,7 +202,7 @@ class InMemoryRunner: session_id=self.session.id, new_message=get_user_content(new_message), ): - events.append(event) + events.append(event) return events def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: