refactor: rename agent_dir to agents_dir and rename app_id to app_name in fast_api.py to make it consistent among every endpoints

PiperOrigin-RevId: 763483339
This commit is contained in:
Xiang (Sean) Zhou 2025-05-26 11:50:44 -07:00 committed by Copybara-Service
parent 6b89ceb49a
commit be0786ea88
6 changed files with 32 additions and 32 deletions

View File

@ -355,7 +355,7 @@ def cli_eval(
# Write eval set results.
local_eval_set_results_manager = LocalEvalSetResultsManager(
agent_dir=os.path.dirname(agent_module_file_path)
agents_dir=os.path.dirname(agent_module_file_path)
)
eval_set_id_to_eval_results = collections.defaultdict(list)
for eval_case_result in eval_results:
@ -500,7 +500,7 @@ def cli_web(
)
app = get_fast_api_app(
agent_dir=agents_dir,
agents_dir=agents_dir,
session_db_url=session_db_url,
allow_origins=allow_origins,
web=True,
@ -601,7 +601,7 @@ def cli_api_server(
config = uvicorn.Config(
get_fast_api_app(
agent_dir=agents_dir,
agents_dir=agents_dir,
session_db_url=session_db_url,
allow_origins=allow_origins,
web=False,

View File

@ -191,7 +191,7 @@ class GetEventGraphResult(common.BaseModel):
def get_fast_api_app(
*,
agent_dir: str,
agents_dir: str,
session_db_url: str = "",
allow_origins: Optional[list[str]] = None,
web: bool,
@ -210,7 +210,7 @@ def get_fast_api_app(
memory_exporter = InMemoryExporter(session_trace_dict)
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
if trace_to_cloud:
envs.load_dotenv_for_agent("", agent_dir)
envs.load_dotenv_for_agent("", agents_dir)
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
processor = export.BatchSpanProcessor(
CloudTraceSpanExporter(project_id=project_id)
@ -249,8 +249,8 @@ def get_fast_api_app(
allow_headers=["*"],
)
if agent_dir not in sys.path:
sys.path.append(agent_dir)
if agents_dir not in sys.path:
sys.path.append(agents_dir)
runner_dict = {}
root_agent_dict = {}
@ -259,8 +259,8 @@ def get_fast_api_app(
artifact_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
eval_set_results_manager = LocalEvalSetResultsManager(agent_dir=agent_dir)
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
# Build the Session service
agent_engine_id = ""
@ -270,7 +270,7 @@ def get_fast_api_app(
agent_engine_id = session_db_url.split("://")[1]
if not agent_engine_id:
raise click.ClickException("Agent engine id can not be empty.")
envs.load_dotenv_for_agent("", agent_dir)
envs.load_dotenv_for_agent("", agents_dir)
session_service = VertexAiSessionService(
os.environ["GOOGLE_CLOUD_PROJECT"],
os.environ["GOOGLE_CLOUD_LOCATION"],
@ -282,7 +282,7 @@ def get_fast_api_app(
@app.get("/list-apps")
def list_apps() -> list[str]:
base_path = Path.cwd() / agent_dir
base_path = Path.cwd() / agents_dir
if not base_path.exists():
raise HTTPException(status_code=404, detail="Path not found")
if not base_path.is_dir():
@ -398,9 +398,9 @@ def get_fast_api_app(
app_name=app_name, user_id=user_id, state=state
)
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str:
return os.path.join(
agent_dir,
agents_dir,
app_name,
eval_set_id + _EVAL_SET_FILE_EXTENSION,
)
@ -490,7 +490,7 @@ def get_fast_api_app(
# Create a mapping from eval set file to all the evals that needed to be
# run.
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
@ -663,9 +663,9 @@ def get_fast_api_app(
@app.post("/run", response_model_exclude_none=True)
async def agent_run(req: AgentRunRequest) -> list[Event]:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
app_name = agent_engine_id if agent_engine_id else req.app_name
session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
app_name=app_name, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
@ -684,10 +684,10 @@ def get_fast_api_app(
@app.post("/run_sse")
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else req.app_name
app_name = agent_engine_id if agent_engine_id else req.app_name
# SSE endpoint
session = await session_service.get_session(
app_name=app_id, user_id=req.user_id, session_id=req.session_id
app_name=app_name, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
@ -726,9 +726,9 @@ def get_fast_api_app(
app_name: str, user_id: str, session_id: str, event_id: str
):
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
app_name=app_name, user_id=user_id, session_id=session_id
)
session_events = session.events if session else []
event = next((x for x in session_events if x.id == event_id), None)
@ -783,9 +783,9 @@ def get_fast_api_app(
await websocket.accept()
# Connect to managed session if agent_engine_id is set.
app_id = agent_engine_id if agent_engine_id else app_name
app_name = agent_engine_id if agent_engine_id else app_name
session = await session_service.get_session(
app_name=app_id, user_id=user_id, session_id=session_id
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
# Accept first so that the client is aware of connection establishment,
@ -855,7 +855,7 @@ def get_fast_api_app(
async def _get_runner_async(app_name: str) -> Runner:
"""Returns the runner for the given app."""
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
if app_name in runner_dict:
return runner_dict[app_name]
root_agent = await _get_root_agent_async(app_name)

View File

@ -36,8 +36,8 @@ def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str:
class LocalEvalSetResultsManager(EvalSetResultsManager):
"""An EvalSetResult manager that stores eval set results locally on disk."""
def __init__(self, agent_dir: str):
self._agent_dir = agent_dir
def __init__(self, agents_dir: str):
self._agents_dir = agents_dir
@override
def save_eval_set_result(
@ -108,4 +108,4 @@ class LocalEvalSetResultsManager(EvalSetResultsManager):
return eval_result_files
def _get_eval_history_dir(self, app_name: str) -> str:
return os.path.join(self._agent_dir, app_name, _ADK_EVAL_HISTORY_DIR)
return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR)

View File

@ -182,8 +182,8 @@ def load_eval_set_from_file(
class LocalEvalSetsManager(EvalSetsManager):
"""An EvalSets manager that stores eval sets locally on disk."""
def __init__(self, agent_dir: str):
self._agent_dir = agent_dir
def __init__(self, agents_dir: str):
self._agents_dir = agents_dir
@override
def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
@ -216,7 +216,7 @@ class LocalEvalSetsManager(EvalSetsManager):
@override
def list_eval_sets(self, app_name: str) -> list[str]:
"""Returns a list of EvalSets that belong to the given app_name."""
eval_set_file_path = os.path.join(self._agent_dir, app_name)
eval_set_file_path = os.path.join(self._agents_dir, app_name)
eval_sets = []
for file in os.listdir(eval_set_file_path):
if file.endswith(_EVAL_SET_FILE_EXTENSION):
@ -247,7 +247,7 @@ class LocalEvalSetsManager(EvalSetsManager):
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
return os.path.join(
self._agent_dir,
self._agents_dir,
app_name,
eval_set_id + _EVAL_SET_FILE_EXTENSION,
)

View File

@ -308,7 +308,7 @@ def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
):
# Get the FastAPI app, but don't actually run it
app = get_fast_api_app(
agent_dir=".", web=True, session_db_url="", allow_origins=["*"]
agents_dir=".", web=True, session_db_url="", allow_origins=["*"]
)
# Create a TestClient that doesn't start a real server

View File

@ -202,7 +202,7 @@ class InMemoryRunner:
session_id=self.session.id,
new_message=get_user_content(new_message),
):
events.append(event)
events.append(event)
return events
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: