mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 09:51:25 -06:00
Refactor Eval Set Management into its own class.
PiperOrigin-RevId: 758378377
This commit is contained in:
parent
303af440ee
commit
cf06cc507a
@ -63,6 +63,7 @@ from ..agents.llm_agent import Agent
|
|||||||
from ..agents.llm_agent import LlmAgent
|
from ..agents.llm_agent import LlmAgent
|
||||||
from ..agents.run_config import StreamingMode
|
from ..agents.run_config import StreamingMode
|
||||||
from ..artifacts import InMemoryArtifactService
|
from ..artifacts import InMemoryArtifactService
|
||||||
|
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
||||||
from ..events.event import Event
|
from ..events.event import Event
|
||||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||||
from ..runners import Runner
|
from ..runners import Runner
|
||||||
@ -252,6 +253,8 @@ def get_fast_api_app(
|
|||||||
artifact_service = InMemoryArtifactService()
|
artifact_service = InMemoryArtifactService()
|
||||||
memory_service = InMemoryMemoryService()
|
memory_service = InMemoryMemoryService()
|
||||||
|
|
||||||
|
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
|
||||||
|
|
||||||
# Build the Session service
|
# Build the Session service
|
||||||
agent_engine_id = ""
|
agent_engine_id = ""
|
||||||
if session_db_url:
|
if session_db_url:
|
||||||
@ -401,28 +404,13 @@ def get_fast_api_app(
|
|||||||
eval_set_id: str,
|
eval_set_id: str,
|
||||||
):
|
):
|
||||||
"""Creates an eval set, given the id."""
|
"""Creates an eval set, given the id."""
|
||||||
pattern = r"^[a-zA-Z0-9_]+$"
|
try:
|
||||||
if not bool(re.fullmatch(pattern, eval_set_id)):
|
eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
||||||
|
except ValueError as ve:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=(
|
detail=str(ve),
|
||||||
f"Invalid eval set id. Eval set id should have the `{pattern}`"
|
) from ve
|
||||||
" format"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# Define the file path
|
|
||||||
new_eval_set_path = _get_eval_set_file_path(
|
|
||||||
app_name, agent_dir, eval_set_id
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Creating eval set file `%s`", new_eval_set_path)
|
|
||||||
|
|
||||||
if not os.path.exists(new_eval_set_path):
|
|
||||||
# Write the JSON string to the file
|
|
||||||
logger.info("Eval set file doesn't exist, we will create a new one.")
|
|
||||||
with open(new_eval_set_path, "w") as f:
|
|
||||||
empty_content = json.dumps([], indent=2)
|
|
||||||
f.write(empty_content)
|
|
||||||
|
|
||||||
@app.get(
|
@app.get(
|
||||||
"/apps/{app_name}/eval_sets",
|
"/apps/{app_name}/eval_sets",
|
||||||
@ -430,15 +418,7 @@ def get_fast_api_app(
|
|||||||
)
|
)
|
||||||
def list_eval_sets(app_name: str) -> list[str]:
|
def list_eval_sets(app_name: str) -> list[str]:
|
||||||
"""Lists all eval sets for the given app."""
|
"""Lists all eval sets for the given app."""
|
||||||
eval_set_file_path = os.path.join(agent_dir, app_name)
|
return eval_sets_manager.list_eval_sets(app_name)
|
||||||
eval_sets = []
|
|
||||||
for file in os.listdir(eval_set_file_path):
|
|
||||||
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
|
||||||
eval_sets.append(
|
|
||||||
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
|
|
||||||
)
|
|
||||||
|
|
||||||
return sorted(eval_sets)
|
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
||||||
@ -447,33 +427,11 @@ def get_fast_api_app(
|
|||||||
async def add_session_to_eval_set(
|
async def add_session_to_eval_set(
|
||||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||||
):
|
):
|
||||||
pattern = r"^[a-zA-Z0-9_]+$"
|
|
||||||
if not bool(re.fullmatch(pattern, req.eval_id)):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the session
|
# Get the session
|
||||||
session = session_service.get_session(
|
session = session_service.get_session(
|
||||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||||
)
|
)
|
||||||
assert session, "Session not found."
|
assert session, "Session not found."
|
||||||
# Load the eval set file data
|
|
||||||
eval_set_file_path = _get_eval_set_file_path(
|
|
||||||
app_name, agent_dir, eval_set_id
|
|
||||||
)
|
|
||||||
with open(eval_set_file_path, "r") as file:
|
|
||||||
eval_set_data = json.load(file) # Load JSON into a list
|
|
||||||
|
|
||||||
if [x for x in eval_set_data if x["name"] == req.eval_id]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=(
|
|
||||||
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
|
|
||||||
" eval set."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert the session data to evaluation format
|
# Convert the session data to evaluation format
|
||||||
test_data = evals.convert_session_to_eval_format(session)
|
test_data = evals.convert_session_to_eval_format(session)
|
||||||
@ -483,7 +441,7 @@ def get_fast_api_app(
|
|||||||
await _get_root_agent_async(app_name)
|
await _get_root_agent_async(app_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_set_data.append({
|
eval_case = {
|
||||||
"name": req.eval_id,
|
"name": req.eval_id,
|
||||||
"data": test_data,
|
"data": test_data,
|
||||||
"initial_session": {
|
"initial_session": {
|
||||||
@ -491,10 +449,11 @@ def get_fast_api_app(
|
|||||||
"app_name": app_name,
|
"app_name": app_name,
|
||||||
"user_id": req.user_id,
|
"user_id": req.user_id,
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
# Serialize the test data to JSON and write to the eval set file.
|
try:
|
||||||
with open(eval_set_file_path, "w") as f:
|
eval_sets_manager.add_eval_case(app_name, eval_set_id, eval_case)
|
||||||
f.write(json.dumps(eval_set_data, indent=2))
|
except ValueError as ve:
|
||||||
|
raise HTTPException(status_code=400, detail=str(ve)) from ve
|
||||||
|
|
||||||
@app.get(
|
@app.get(
|
||||||
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
|
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
|
||||||
@ -505,12 +464,7 @@ def get_fast_api_app(
|
|||||||
eval_set_id: str,
|
eval_set_id: str,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Lists all evals in an eval set."""
|
"""Lists all evals in an eval set."""
|
||||||
# Load the eval set file data
|
eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
||||||
eval_set_file_path = _get_eval_set_file_path(
|
|
||||||
app_name, agent_dir, eval_set_id
|
|
||||||
)
|
|
||||||
with open(eval_set_file_path, "r") as file:
|
|
||||||
eval_set_data = json.load(file) # Load JSON into a list
|
|
||||||
|
|
||||||
return sorted([x["name"] for x in eval_set_data])
|
return sorted([x["name"] for x in eval_set_data])
|
||||||
|
|
||||||
|
40
src/google/adk/evaluation/eval_sets_manager.py
Normal file
40
src/google/adk/evaluation/eval_sets_manager.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class EvalSetsManager(ABC):
|
||||||
|
"""An interface to manage an Eval Sets."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_eval_set(self, app_name: str, eval_set_id: str) -> Any:
|
||||||
|
"""Returns an EvalSet identified by an app_name and eval_set_id."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_eval_set(self, app_name: str, eval_set_id: str):
|
||||||
|
"""Creates an empty EvalSet given the app_name and eval_set_id."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||||
|
"""Returns a list of EvalSets that belong to the given app_name."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: Any):
|
||||||
|
"""Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id."""
|
||||||
|
raise NotImplementedError()
|
106
src/google/adk/evaluation/local_eval_sets_manager.py
Normal file
106
src/google/adk/evaluation/local_eval_sets_manager.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
from typing_extensions import override
|
||||||
|
from .eval_sets_manager import EvalSetsManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEvalSetsManager(EvalSetsManager):
|
||||||
|
"""An EvalSets manager that stores eval sets locally on disk."""
|
||||||
|
|
||||||
|
def __init__(self, agent_dir: str):
|
||||||
|
self._agent_dir = agent_dir
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_eval_set(self, app_name: str, eval_set_id: str) -> Any:
|
||||||
|
"""Returns an EvalSet identified by an app_name and eval_set_id."""
|
||||||
|
# Load the eval set file data
|
||||||
|
eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
|
||||||
|
with open(eval_set_file_path, "r") as file:
|
||||||
|
return json.load(file) # Load JSON into a list
|
||||||
|
|
||||||
|
@override
|
||||||
|
def create_eval_set(self, app_name: str, eval_set_id: str):
|
||||||
|
"""Creates an empty EvalSet given the app_name and eval_set_id."""
|
||||||
|
self._validate_id(id_name="Eval Set Id", id_value=eval_set_id)
|
||||||
|
|
||||||
|
# Define the file path
|
||||||
|
new_eval_set_path = self._get_eval_set_file_path(app_name, eval_set_id)
|
||||||
|
|
||||||
|
logger.info("Creating eval set file `%s`", new_eval_set_path)
|
||||||
|
|
||||||
|
if not os.path.exists(new_eval_set_path):
|
||||||
|
# Write the JSON string to the file
|
||||||
|
logger.info("Eval set file doesn't exist, we will create a new one.")
|
||||||
|
with open(new_eval_set_path, "w") as f:
|
||||||
|
empty_content = json.dumps([], indent=2)
|
||||||
|
f.write(empty_content)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||||
|
"""Returns a list of EvalSets that belong to the given app_name."""
|
||||||
|
eval_set_file_path = os.path.join(self._agent_dir, app_name)
|
||||||
|
eval_sets = []
|
||||||
|
for file in os.listdir(eval_set_file_path):
|
||||||
|
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
||||||
|
eval_sets.append(
|
||||||
|
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
|
||||||
|
)
|
||||||
|
|
||||||
|
return sorted(eval_sets)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: Any):
|
||||||
|
"""Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id."""
|
||||||
|
eval_case_id = eval_case["name"]
|
||||||
|
self._validate_id(id_name="Eval Case Id", id_value=eval_case_id)
|
||||||
|
|
||||||
|
# Load the eval set file data
|
||||||
|
eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
|
||||||
|
with open(eval_set_file_path, "r") as file:
|
||||||
|
eval_set_data = json.load(file) # Load JSON into a list
|
||||||
|
|
||||||
|
if [x for x in eval_set_data if x["name"] == eval_case_id]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Eval id `{eval_case_id}` already exists in `{eval_set_id}`"
|
||||||
|
" eval set.",
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_set_data.append(eval_case)
|
||||||
|
# Serialize the test data to JSON and write to the eval set file.
|
||||||
|
with open(eval_set_file_path, "w") as f:
|
||||||
|
f.write(json.dumps(eval_set_data, indent=2))
|
||||||
|
|
||||||
|
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
|
||||||
|
return os.path.join(
|
||||||
|
self._agent_dir,
|
||||||
|
app_name,
|
||||||
|
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_id(self, id_name: str, id_value: str):
|
||||||
|
pattern = r"^[a-zA-Z0-9_]+$"
|
||||||
|
if not bool(re.fullmatch(pattern, id_value)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user