mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
refactor: rename agent_dir to agents_dir and rename app_id to app_name in fast_api.py to make it consistent among every endpoints
PiperOrigin-RevId: 763483339
This commit is contained in:
parent
6b89ceb49a
commit
be0786ea88
@ -355,7 +355,7 @@ def cli_eval(
|
||||
|
||||
# Write eval set results.
|
||||
local_eval_set_results_manager = LocalEvalSetResultsManager(
|
||||
agent_dir=os.path.dirname(agent_module_file_path)
|
||||
agents_dir=os.path.dirname(agent_module_file_path)
|
||||
)
|
||||
eval_set_id_to_eval_results = collections.defaultdict(list)
|
||||
for eval_case_result in eval_results:
|
||||
@ -500,7 +500,7 @@ def cli_web(
|
||||
)
|
||||
|
||||
app = get_fast_api_app(
|
||||
agent_dir=agents_dir,
|
||||
agents_dir=agents_dir,
|
||||
session_db_url=session_db_url,
|
||||
allow_origins=allow_origins,
|
||||
web=True,
|
||||
@ -601,7 +601,7 @@ def cli_api_server(
|
||||
|
||||
config = uvicorn.Config(
|
||||
get_fast_api_app(
|
||||
agent_dir=agents_dir,
|
||||
agents_dir=agents_dir,
|
||||
session_db_url=session_db_url,
|
||||
allow_origins=allow_origins,
|
||||
web=False,
|
||||
|
@ -191,7 +191,7 @@ class GetEventGraphResult(common.BaseModel):
|
||||
|
||||
def get_fast_api_app(
|
||||
*,
|
||||
agent_dir: str,
|
||||
agents_dir: str,
|
||||
session_db_url: str = "",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
web: bool,
|
||||
@ -210,7 +210,7 @@ def get_fast_api_app(
|
||||
memory_exporter = InMemoryExporter(session_trace_dict)
|
||||
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
|
||||
if trace_to_cloud:
|
||||
envs.load_dotenv_for_agent("", agent_dir)
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
||||
processor = export.BatchSpanProcessor(
|
||||
CloudTraceSpanExporter(project_id=project_id)
|
||||
@ -249,8 +249,8 @@ def get_fast_api_app(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if agent_dir not in sys.path:
|
||||
sys.path.append(agent_dir)
|
||||
if agents_dir not in sys.path:
|
||||
sys.path.append(agents_dir)
|
||||
|
||||
runner_dict = {}
|
||||
root_agent_dict = {}
|
||||
@ -259,8 +259,8 @@ def get_fast_api_app(
|
||||
artifact_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
|
||||
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
|
||||
eval_set_results_manager = LocalEvalSetResultsManager(agent_dir=agent_dir)
|
||||
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
|
||||
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
|
||||
|
||||
# Build the Session service
|
||||
agent_engine_id = ""
|
||||
@ -270,7 +270,7 @@ def get_fast_api_app(
|
||||
agent_engine_id = session_db_url.split("://")[1]
|
||||
if not agent_engine_id:
|
||||
raise click.ClickException("Agent engine id can not be empty.")
|
||||
envs.load_dotenv_for_agent("", agent_dir)
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
session_service = VertexAiSessionService(
|
||||
os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
@ -282,7 +282,7 @@ def get_fast_api_app(
|
||||
|
||||
@app.get("/list-apps")
|
||||
def list_apps() -> list[str]:
|
||||
base_path = Path.cwd() / agent_dir
|
||||
base_path = Path.cwd() / agents_dir
|
||||
if not base_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Path not found")
|
||||
if not base_path.is_dir():
|
||||
@ -398,9 +398,9 @@ def get_fast_api_app(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
|
||||
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
|
||||
def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str:
|
||||
return os.path.join(
|
||||
agent_dir,
|
||||
agents_dir,
|
||||
app_name,
|
||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||
)
|
||||
@ -490,7 +490,7 @@ def get_fast_api_app(
|
||||
|
||||
# Create a mapping from eval set file to all the evals that needed to be
|
||||
# run.
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
||||
|
||||
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
||||
|
||||
@ -663,9 +663,9 @@ def get_fast_api_app(
|
||||
@app.post("/run", response_model_exclude_none=True)
|
||||
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
app_name = agent_engine_id if agent_engine_id else req.app_name
|
||||
session = await session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
@ -684,10 +684,10 @@ def get_fast_api_app(
|
||||
@app.post("/run_sse")
|
||||
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
app_name = agent_engine_id if agent_engine_id else req.app_name
|
||||
# SSE endpoint
|
||||
session = await session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
@ -726,9 +726,9 @@ def get_fast_api_app(
|
||||
app_name: str, user_id: str, session_id: str, event_id: str
|
||||
):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
session_events = session.events if session else []
|
||||
event = next((x for x in session_events if x.id == event_id), None)
|
||||
@ -783,9 +783,9 @@ def get_fast_api_app(
|
||||
await websocket.accept()
|
||||
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = await session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
# Accept first so that the client is aware of connection establishment,
|
||||
@ -855,7 +855,7 @@ def get_fast_api_app(
|
||||
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
"""Returns the runner for the given app."""
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
||||
if app_name in runner_dict:
|
||||
return runner_dict[app_name]
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
|
@ -36,8 +36,8 @@ def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str:
|
||||
class LocalEvalSetResultsManager(EvalSetResultsManager):
|
||||
"""An EvalSetResult manager that stores eval set results locally on disk."""
|
||||
|
||||
def __init__(self, agent_dir: str):
|
||||
self._agent_dir = agent_dir
|
||||
def __init__(self, agents_dir: str):
|
||||
self._agents_dir = agents_dir
|
||||
|
||||
@override
|
||||
def save_eval_set_result(
|
||||
@ -108,4 +108,4 @@ class LocalEvalSetResultsManager(EvalSetResultsManager):
|
||||
return eval_result_files
|
||||
|
||||
def _get_eval_history_dir(self, app_name: str) -> str:
|
||||
return os.path.join(self._agent_dir, app_name, _ADK_EVAL_HISTORY_DIR)
|
||||
return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR)
|
||||
|
@ -182,8 +182,8 @@ def load_eval_set_from_file(
|
||||
class LocalEvalSetsManager(EvalSetsManager):
|
||||
"""An EvalSets manager that stores eval sets locally on disk."""
|
||||
|
||||
def __init__(self, agent_dir: str):
|
||||
self._agent_dir = agent_dir
|
||||
def __init__(self, agents_dir: str):
|
||||
self._agents_dir = agents_dir
|
||||
|
||||
@override
|
||||
def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
|
||||
@ -216,7 +216,7 @@ class LocalEvalSetsManager(EvalSetsManager):
|
||||
@override
|
||||
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||
"""Returns a list of EvalSets that belong to the given app_name."""
|
||||
eval_set_file_path = os.path.join(self._agent_dir, app_name)
|
||||
eval_set_file_path = os.path.join(self._agents_dir, app_name)
|
||||
eval_sets = []
|
||||
for file in os.listdir(eval_set_file_path):
|
||||
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
||||
@ -247,7 +247,7 @@ class LocalEvalSetsManager(EvalSetsManager):
|
||||
|
||||
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
|
||||
return os.path.join(
|
||||
self._agent_dir,
|
||||
self._agents_dir,
|
||||
app_name,
|
||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||
)
|
||||
|
@ -308,7 +308,7 @@ def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
|
||||
):
|
||||
# Get the FastAPI app, but don't actually run it
|
||||
app = get_fast_api_app(
|
||||
agent_dir=".", web=True, session_db_url="", allow_origins=["*"]
|
||||
agents_dir=".", web=True, session_db_url="", allow_origins=["*"]
|
||||
)
|
||||
|
||||
# Create a TestClient that doesn't start a real server
|
||||
|
Loading…
Reference in New Issue
Block a user