mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
refactor: rename agent_dir to agents_dir and rename app_id to app_name in fast_api.py to make it consistent among every endpoints
PiperOrigin-RevId: 763483339
This commit is contained in:
parent
6b89ceb49a
commit
be0786ea88
@ -355,7 +355,7 @@ def cli_eval(
|
|||||||
|
|
||||||
# Write eval set results.
|
# Write eval set results.
|
||||||
local_eval_set_results_manager = LocalEvalSetResultsManager(
|
local_eval_set_results_manager = LocalEvalSetResultsManager(
|
||||||
agent_dir=os.path.dirname(agent_module_file_path)
|
agents_dir=os.path.dirname(agent_module_file_path)
|
||||||
)
|
)
|
||||||
eval_set_id_to_eval_results = collections.defaultdict(list)
|
eval_set_id_to_eval_results = collections.defaultdict(list)
|
||||||
for eval_case_result in eval_results:
|
for eval_case_result in eval_results:
|
||||||
@ -500,7 +500,7 @@ def cli_web(
|
|||||||
)
|
)
|
||||||
|
|
||||||
app = get_fast_api_app(
|
app = get_fast_api_app(
|
||||||
agent_dir=agents_dir,
|
agents_dir=agents_dir,
|
||||||
session_db_url=session_db_url,
|
session_db_url=session_db_url,
|
||||||
allow_origins=allow_origins,
|
allow_origins=allow_origins,
|
||||||
web=True,
|
web=True,
|
||||||
@ -601,7 +601,7 @@ def cli_api_server(
|
|||||||
|
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(
|
||||||
get_fast_api_app(
|
get_fast_api_app(
|
||||||
agent_dir=agents_dir,
|
agents_dir=agents_dir,
|
||||||
session_db_url=session_db_url,
|
session_db_url=session_db_url,
|
||||||
allow_origins=allow_origins,
|
allow_origins=allow_origins,
|
||||||
web=False,
|
web=False,
|
||||||
|
@ -191,7 +191,7 @@ class GetEventGraphResult(common.BaseModel):
|
|||||||
|
|
||||||
def get_fast_api_app(
|
def get_fast_api_app(
|
||||||
*,
|
*,
|
||||||
agent_dir: str,
|
agents_dir: str,
|
||||||
session_db_url: str = "",
|
session_db_url: str = "",
|
||||||
allow_origins: Optional[list[str]] = None,
|
allow_origins: Optional[list[str]] = None,
|
||||||
web: bool,
|
web: bool,
|
||||||
@ -210,7 +210,7 @@ def get_fast_api_app(
|
|||||||
memory_exporter = InMemoryExporter(session_trace_dict)
|
memory_exporter = InMemoryExporter(session_trace_dict)
|
||||||
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
|
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
|
||||||
if trace_to_cloud:
|
if trace_to_cloud:
|
||||||
envs.load_dotenv_for_agent("", agent_dir)
|
envs.load_dotenv_for_agent("", agents_dir)
|
||||||
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
||||||
processor = export.BatchSpanProcessor(
|
processor = export.BatchSpanProcessor(
|
||||||
CloudTraceSpanExporter(project_id=project_id)
|
CloudTraceSpanExporter(project_id=project_id)
|
||||||
@ -249,8 +249,8 @@ def get_fast_api_app(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if agent_dir not in sys.path:
|
if agents_dir not in sys.path:
|
||||||
sys.path.append(agent_dir)
|
sys.path.append(agents_dir)
|
||||||
|
|
||||||
runner_dict = {}
|
runner_dict = {}
|
||||||
root_agent_dict = {}
|
root_agent_dict = {}
|
||||||
@ -259,8 +259,8 @@ def get_fast_api_app(
|
|||||||
artifact_service = InMemoryArtifactService()
|
artifact_service = InMemoryArtifactService()
|
||||||
memory_service = InMemoryMemoryService()
|
memory_service = InMemoryMemoryService()
|
||||||
|
|
||||||
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
|
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
|
||||||
eval_set_results_manager = LocalEvalSetResultsManager(agent_dir=agent_dir)
|
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
|
||||||
|
|
||||||
# Build the Session service
|
# Build the Session service
|
||||||
agent_engine_id = ""
|
agent_engine_id = ""
|
||||||
@ -270,7 +270,7 @@ def get_fast_api_app(
|
|||||||
agent_engine_id = session_db_url.split("://")[1]
|
agent_engine_id = session_db_url.split("://")[1]
|
||||||
if not agent_engine_id:
|
if not agent_engine_id:
|
||||||
raise click.ClickException("Agent engine id can not be empty.")
|
raise click.ClickException("Agent engine id can not be empty.")
|
||||||
envs.load_dotenv_for_agent("", agent_dir)
|
envs.load_dotenv_for_agent("", agents_dir)
|
||||||
session_service = VertexAiSessionService(
|
session_service = VertexAiSessionService(
|
||||||
os.environ["GOOGLE_CLOUD_PROJECT"],
|
os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||||
os.environ["GOOGLE_CLOUD_LOCATION"],
|
os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||||
@ -282,7 +282,7 @@ def get_fast_api_app(
|
|||||||
|
|
||||||
@app.get("/list-apps")
|
@app.get("/list-apps")
|
||||||
def list_apps() -> list[str]:
|
def list_apps() -> list[str]:
|
||||||
base_path = Path.cwd() / agent_dir
|
base_path = Path.cwd() / agents_dir
|
||||||
if not base_path.exists():
|
if not base_path.exists():
|
||||||
raise HTTPException(status_code=404, detail="Path not found")
|
raise HTTPException(status_code=404, detail="Path not found")
|
||||||
if not base_path.is_dir():
|
if not base_path.is_dir():
|
||||||
@ -398,9 +398,9 @@ def get_fast_api_app(
|
|||||||
app_name=app_name, user_id=user_id, state=state
|
app_name=app_name, user_id=user_id, state=state
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
|
def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str:
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
agent_dir,
|
agents_dir,
|
||||||
app_name,
|
app_name,
|
||||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||||
)
|
)
|
||||||
@ -490,7 +490,7 @@ def get_fast_api_app(
|
|||||||
|
|
||||||
# Create a mapping from eval set file to all the evals that needed to be
|
# Create a mapping from eval set file to all the evals that needed to be
|
||||||
# run.
|
# run.
|
||||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
||||||
|
|
||||||
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
||||||
|
|
||||||
@ -663,9 +663,9 @@ def get_fast_api_app(
|
|||||||
@app.post("/run", response_model_exclude_none=True)
|
@app.post("/run", response_model_exclude_none=True)
|
||||||
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
||||||
# Connect to managed session if agent_engine_id is set.
|
# Connect to managed session if agent_engine_id is set.
|
||||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
app_name = agent_engine_id if agent_engine_id else req.app_name
|
||||||
session = await session_service.get_session(
|
session = await session_service.get_session(
|
||||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||||
)
|
)
|
||||||
if not session:
|
if not session:
|
||||||
raise HTTPException(status_code=404, detail="Session not found")
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
@ -684,10 +684,10 @@ def get_fast_api_app(
|
|||||||
@app.post("/run_sse")
|
@app.post("/run_sse")
|
||||||
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
||||||
# Connect to managed session if agent_engine_id is set.
|
# Connect to managed session if agent_engine_id is set.
|
||||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
app_name = agent_engine_id if agent_engine_id else req.app_name
|
||||||
# SSE endpoint
|
# SSE endpoint
|
||||||
session = await session_service.get_session(
|
session = await session_service.get_session(
|
||||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||||
)
|
)
|
||||||
if not session:
|
if not session:
|
||||||
raise HTTPException(status_code=404, detail="Session not found")
|
raise HTTPException(status_code=404, detail="Session not found")
|
||||||
@ -726,9 +726,9 @@ def get_fast_api_app(
|
|||||||
app_name: str, user_id: str, session_id: str, event_id: str
|
app_name: str, user_id: str, session_id: str, event_id: str
|
||||||
):
|
):
|
||||||
# Connect to managed session if agent_engine_id is set.
|
# Connect to managed session if agent_engine_id is set.
|
||||||
app_id = agent_engine_id if agent_engine_id else app_name
|
app_name = agent_engine_id if agent_engine_id else app_name
|
||||||
session = await session_service.get_session(
|
session = await session_service.get_session(
|
||||||
app_name=app_id, user_id=user_id, session_id=session_id
|
app_name=app_name, user_id=user_id, session_id=session_id
|
||||||
)
|
)
|
||||||
session_events = session.events if session else []
|
session_events = session.events if session else []
|
||||||
event = next((x for x in session_events if x.id == event_id), None)
|
event = next((x for x in session_events if x.id == event_id), None)
|
||||||
@ -783,9 +783,9 @@ def get_fast_api_app(
|
|||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
# Connect to managed session if agent_engine_id is set.
|
# Connect to managed session if agent_engine_id is set.
|
||||||
app_id = agent_engine_id if agent_engine_id else app_name
|
app_name = agent_engine_id if agent_engine_id else app_name
|
||||||
session = await session_service.get_session(
|
session = await session_service.get_session(
|
||||||
app_name=app_id, user_id=user_id, session_id=session_id
|
app_name=app_name, user_id=user_id, session_id=session_id
|
||||||
)
|
)
|
||||||
if not session:
|
if not session:
|
||||||
# Accept first so that the client is aware of connection establishment,
|
# Accept first so that the client is aware of connection establishment,
|
||||||
@ -855,7 +855,7 @@ def get_fast_api_app(
|
|||||||
|
|
||||||
async def _get_runner_async(app_name: str) -> Runner:
|
async def _get_runner_async(app_name: str) -> Runner:
|
||||||
"""Returns the runner for the given app."""
|
"""Returns the runner for the given app."""
|
||||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
||||||
if app_name in runner_dict:
|
if app_name in runner_dict:
|
||||||
return runner_dict[app_name]
|
return runner_dict[app_name]
|
||||||
root_agent = await _get_root_agent_async(app_name)
|
root_agent = await _get_root_agent_async(app_name)
|
||||||
|
@ -36,8 +36,8 @@ def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str:
|
|||||||
class LocalEvalSetResultsManager(EvalSetResultsManager):
|
class LocalEvalSetResultsManager(EvalSetResultsManager):
|
||||||
"""An EvalSetResult manager that stores eval set results locally on disk."""
|
"""An EvalSetResult manager that stores eval set results locally on disk."""
|
||||||
|
|
||||||
def __init__(self, agent_dir: str):
|
def __init__(self, agents_dir: str):
|
||||||
self._agent_dir = agent_dir
|
self._agents_dir = agents_dir
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def save_eval_set_result(
|
def save_eval_set_result(
|
||||||
@ -108,4 +108,4 @@ class LocalEvalSetResultsManager(EvalSetResultsManager):
|
|||||||
return eval_result_files
|
return eval_result_files
|
||||||
|
|
||||||
def _get_eval_history_dir(self, app_name: str) -> str:
|
def _get_eval_history_dir(self, app_name: str) -> str:
|
||||||
return os.path.join(self._agent_dir, app_name, _ADK_EVAL_HISTORY_DIR)
|
return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR)
|
||||||
|
@ -182,8 +182,8 @@ def load_eval_set_from_file(
|
|||||||
class LocalEvalSetsManager(EvalSetsManager):
|
class LocalEvalSetsManager(EvalSetsManager):
|
||||||
"""An EvalSets manager that stores eval sets locally on disk."""
|
"""An EvalSets manager that stores eval sets locally on disk."""
|
||||||
|
|
||||||
def __init__(self, agent_dir: str):
|
def __init__(self, agents_dir: str):
|
||||||
self._agent_dir = agent_dir
|
self._agents_dir = agents_dir
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
|
def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
|
||||||
@ -216,7 +216,7 @@ class LocalEvalSetsManager(EvalSetsManager):
|
|||||||
@override
|
@override
|
||||||
def list_eval_sets(self, app_name: str) -> list[str]:
|
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||||
"""Returns a list of EvalSets that belong to the given app_name."""
|
"""Returns a list of EvalSets that belong to the given app_name."""
|
||||||
eval_set_file_path = os.path.join(self._agent_dir, app_name)
|
eval_set_file_path = os.path.join(self._agents_dir, app_name)
|
||||||
eval_sets = []
|
eval_sets = []
|
||||||
for file in os.listdir(eval_set_file_path):
|
for file in os.listdir(eval_set_file_path):
|
||||||
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
||||||
@ -247,7 +247,7 @@ class LocalEvalSetsManager(EvalSetsManager):
|
|||||||
|
|
||||||
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
|
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
|
||||||
return os.path.join(
|
return os.path.join(
|
||||||
self._agent_dir,
|
self._agents_dir,
|
||||||
app_name,
|
app_name,
|
||||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||||
)
|
)
|
||||||
|
@ -308,7 +308,7 @@ def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
|
|||||||
):
|
):
|
||||||
# Get the FastAPI app, but don't actually run it
|
# Get the FastAPI app, but don't actually run it
|
||||||
app = get_fast_api_app(
|
app = get_fast_api_app(
|
||||||
agent_dir=".", web=True, session_db_url="", allow_origins=["*"]
|
agents_dir=".", web=True, session_db_url="", allow_origins=["*"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a TestClient that doesn't start a real server
|
# Create a TestClient that doesn't start a real server
|
||||||
|
Loading…
Reference in New Issue
Block a user