mirror of https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 19:32:21 -06:00

No public description
PiperOrigin-RevId: 748777998

commit 61d4be2d76
parent 290058eb05
committed by hangfei
13  tests/unittests/cli/__init__.py  Normal file
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13  tests/unittests/cli/utils/__init__.py  Normal file
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
434  tests/unittests/cli/utils/test_evals.py  Normal file
@@ -0,0 +1,434 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for utilities in eval."""

from google.adk.cli.utils.evals import convert_session_to_eval_format
from google.adk.events.event import Event
from google.adk.sessions.session import Session
from google.genai import types


def build_event(author: str, parts_content: list[dict]) -> Event:
  """Builds an Event object with specified parts."""
  parts = []
  for p_data in parts_content:
    part_args = {}
    if "text" in p_data:
      part_args["text"] = p_data["text"]
    if "func_name" in p_data:
      part_args["function_call"] = types.FunctionCall(
          name=p_data.get("func_name"), args=p_data.get("func_args")
      )
    # Add other part types here if needed for future tests
    parts.append(types.Part(**part_args))
  return Event(author=author, content=types.Content(parts=parts))


def test_convert_empty_session():
  """Test conversion function with empty events list in Session."""
  # Pydantic models require mandatory fields for instantiation
  session_empty_events = Session(
      id="s1", app_name="app", user_id="u1", events=[]
  )
  assert not convert_session_to_eval_format(session_empty_events)


def test_convert_none_session():
  """Test conversion function with None Session."""
  assert not convert_session_to_eval_format(None)


def test_convert_session_skips_initial_non_user_events():
  """Test that events before the first user event are skipped."""
  events = [
      build_event("model", [{"text": "Hello"}]),
      build_event("user", [{"text": "How are you?"}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [
      {
          "query": "How are you?",
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "",
      },
  ]
  assert convert_session_to_eval_format(session) == expected


def test_convert_single_turn_text_only():
  """Test a single user query followed by a single agent text response."""
  events = [
      build_event("user", [{"text": "What is the time?"}]),
      build_event("root_agent", [{"text": "It is 3 PM."}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [{
      "query": "What is the time?",
      "expected_tool_use": [],
      "expected_intermediate_agent_responses": [],
      "reference": "It is 3 PM.",
  }]
  assert convert_session_to_eval_format(session) == expected


def test_convert_single_turn_tool_only():
  """Test a single user query followed by a single agent tool call."""
  events = [
      build_event("user", [{"text": "Get weather for Seattle"}]),
      build_event(
          "root_agent",
          [{"func_name": "get_weather", "func_args": {"city": "Seattle"}}],
      ),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [{
      "query": "Get weather for Seattle",
      "expected_tool_use": [
          {"tool_name": "get_weather", "tool_input": {"city": "Seattle"}}
      ],
      "expected_intermediate_agent_responses": [],
      "reference": "",
  }]
  assert convert_session_to_eval_format(session) == expected


def test_convert_single_turn_multiple_tools_and_texts():
  """Test a turn with multiple agent responses (tools and text)."""
  events = [
      build_event("user", [{"text": "Do task A then task B"}]),
      build_event(
          "root_agent", [{"text": "Okay, starting task A."}]
      ),  # Intermediate Text 1
      build_event(
          "root_agent", [{"func_name": "task_A", "func_args": {"param": 1}}]
      ),  # Tool 1
      build_event(
          "root_agent", [{"text": "Task A done. Now starting task B."}]
      ),  # Intermediate Text 2
      build_event(
          "another_agent", [{"func_name": "task_B", "func_args": {}}]
      ),  # Tool 2
      build_event(
          "root_agent", [{"text": "All tasks completed."}]
      ),  # Final Text (Reference)
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [{
      "query": "Do task A then task B",
      "expected_tool_use": [
          {"tool_name": "task_A", "tool_input": {"param": 1}},
          {"tool_name": "task_B", "tool_input": {}},
      ],
      "expected_intermediate_agent_responses": [
          {"author": "root_agent", "text": "Okay, starting task A."},
          {
              "author": "root_agent",
              "text": "Task A done. Now starting task B.",
          },
      ],
      "reference": "All tasks completed.",
  }]
  assert convert_session_to_eval_format(session) == expected


def test_convert_multi_turn_session():
  """Test a session with multiple user/agent turns."""
  events = [
      build_event("user", [{"text": "Query 1"}]),
      build_event("agent", [{"text": "Response 1"}]),
      build_event("user", [{"text": "Query 2"}]),
      build_event("agent", [{"func_name": "tool_X", "func_args": {}}]),
      build_event("agent", [{"text": "Response 2"}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [
      {  # Turn 1
          "query": "Query 1",
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "Response 1",
      },
      {  # Turn 2
          "query": "Query 2",
          "expected_tool_use": [{"tool_name": "tool_X", "tool_input": {}}],
          "expected_intermediate_agent_responses": [],
          "reference": "Response 2",
      },
  ]
  assert convert_session_to_eval_format(session) == expected


def test_convert_agent_event_multiple_parts():
  """Test an agent event with both text and tool call parts."""
  events = [
      build_event("user", [{"text": "Do something complex"}]),
      # Build event with multiple dicts in parts_content list
      build_event(
          "agent",
          [
              {"text": "Okay, doing it."},
              {"func_name": "complex_tool", "func_args": {"value": True}},
          ],
      ),
      build_event("agent", [{"text": "Finished."}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [{
      "query": "Do something complex",
      "expected_tool_use": [
          {"tool_name": "complex_tool", "tool_input": {"value": True}}
      ],
      "expected_intermediate_agent_responses": [{
          "author": "agent",
          "text": "Okay, doing it.",
      }],  # Text from first part of agent event
      "reference": "Finished.",  # Text from second agent event
  }]
  assert convert_session_to_eval_format(session) == expected


def test_convert_handles_missing_content_or_parts():
  """Test that events missing content or parts are skipped gracefully."""
  events = [
      build_event("user", [{"text": "Query 1"}]),
      Event(author="agent", content=None),  # Agent event missing content
      build_event("agent", [{"text": "Response 1"}]),
      Event(author="user", content=None),  # User event missing content
      build_event("user", [{"text": "Query 2"}]),
      Event(
          author="agent", content=types.Content(parts=[])
      ),  # Agent event with empty parts list
      build_event("agent", [{"text": "Response 2"}]),
      # User event with content but no parts (or None parts)
      Event(author="user", content=types.Content(parts=None)),
      build_event("user", [{"text": "Query 3"}]),
      build_event("agent", [{"text": "Response 3"}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [
      {  # Turn 1 (from Query 1)
          "query": "Query 1",
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "Response 1",
      },
      {  # Turn 2 (from Query 2 - user event with None content was skipped)
          "query": "Query 2",
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "Response 2",
      },
      {  # Turn 3 (from Query 3 - user event with None parts was skipped)
          "query": "Query 3",
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "Response 3",
      },
  ]
  assert convert_session_to_eval_format(session) == expected


def test_convert_handles_missing_tool_name_or_args():
  """Test tool calls with missing name or args."""
  events = [
      build_event("user", [{"text": "Call tools"}]),
      # Event where FunctionCall has name=None
      Event(
          author="agent",
          content=types.Content(
              parts=[
                  types.Part(
                      function_call=types.FunctionCall(name=None, args={"a": 1})
                  )
              ]
          ),
      ),
      # Event where FunctionCall has args=None
      Event(
          author="agent",
          content=types.Content(
              parts=[
                  types.Part(
                      function_call=types.FunctionCall(name="tool_B", args=None)
                  )
              ]
          ),
      ),
      # Event where FunctionCall part exists but FunctionCall object is None
      # (should skip)
      Event(
          author="agent",
          content=types.Content(
              parts=[types.Part(function_call=None, text="some text")]
          ),
      ),
      build_event("agent", [{"text": "Done"}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [{
      "query": "Call tools",
      "expected_tool_use": [
          {"tool_name": "", "tool_input": {"a": 1}},  # Defaults name to ""
          {"tool_name": "tool_B", "tool_input": {}},  # Defaults args to {}
      ],
      "expected_intermediate_agent_responses": [{
          "author": "agent",
          "text": "some text",
      }],  # Text part from the event where function_call was None
      "reference": "Done",
  }]
  assert convert_session_to_eval_format(session) == expected


def test_convert_handles_missing_user_query_text():
  """Test user event where the first part has no text."""
  events = [
      # Event where user part has text=None
      Event(
          author="user", content=types.Content(parts=[types.Part(text=None)])
      ),
      build_event("agent", [{"text": "Response 1"}]),
      # Event where user part has text=""
      build_event("user", [{"text": ""}]),
      build_event("agent", [{"text": "Response 2"}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [
      {
          "query": "",  # Defaults to "" if text is None
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "Response 1",
      },
      {
          "query": "",  # Defaults to "" if text is ""
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "Response 2",
      },
  ]
  assert convert_session_to_eval_format(session) == expected


def test_convert_handles_empty_agent_text():
  """Test agent responses with empty string text."""
  events = [
      build_event("user", [{"text": "Query"}]),
      build_event("agent", [{"text": "Okay"}]),
      build_event("agent", [{"text": ""}]),  # Empty text
      build_event("agent", [{"text": "Done"}]),
  ]
  session = Session(id="s1", app_name="app", user_id="u1", events=events)
  expected = [{
      "query": "Query",
      "expected_tool_use": [],
      "expected_intermediate_agent_responses": [
          {"author": "agent", "text": "Okay"},
      ],
      "reference": "Done",
  }]
  assert convert_session_to_eval_format(session) == expected


def test_convert_complex_sample_session():
  """Test using the complex sample session provided earlier."""
  events = [
      build_event("user", [{"text": "What can you do?"}]),
      build_event(
          "root_agent",
          [{"text": "I can roll dice and check if numbers are prime. \n"}],
      ),
      build_event(
          "user",
          [{
              "text": (
                  "Roll a 8 sided dice and then check if 90 is a prime number"
                  " or not."
              )
          }],
      ),
      build_event(
          "root_agent",
          [{
              "func_name": "transfer_to_agent",
              "func_args": {"agent_name": "roll_agent"},
          }],
      ),
      # Skipping FunctionResponse events as they don't have text/functionCall
      # parts used by converter
      build_event(
          "roll_agent", [{"func_name": "roll_die", "func_args": {"sides": 8}}]
      ),
      # Skipping FunctionResponse
      build_event(
          "roll_agent",
          [
              {"text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n"},
              {
                  "func_name": "transfer_to_agent",
                  "func_args": {"agent_name": "prime_agent"},
              },
          ],
      ),
      # Skipping FunctionResponse
      build_event(
          "prime_agent",
          [{"func_name": "check_prime", "func_args": {"nums": [90]}}],
      ),
      # Skipping FunctionResponse
      build_event("prime_agent", [{"text": "90 is not a prime number. \n"}]),
  ]
  session = Session(
      id="some_id",
      app_name="hello_world_ma",
      user_id="user",
      events=events,
  )
  expected = [
      {
          "query": "What can you do?",
          "expected_tool_use": [],
          "expected_intermediate_agent_responses": [],
          "reference": "I can roll dice and check if numbers are prime. \n",
      },
      {
          "query": (
              "Roll a 8 sided dice and then check if 90 is a prime number or"
              " not."
          ),
          "expected_tool_use": [
              {
                  "tool_name": "transfer_to_agent",
                  "tool_input": {"agent_name": "roll_agent"},
              },
              {"tool_name": "roll_die", "tool_input": {"sides": 8}},
              {
                  "tool_name": "transfer_to_agent",
                  "tool_input": {"agent_name": "prime_agent"},
              },  # From combined event
              {"tool_name": "check_prime", "tool_input": {"nums": [90]}},
          ],
          "expected_intermediate_agent_responses": [{
              "author": "roll_agent",
              "text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n",
          }],  # Text from combined event
          "reference": "90 is not a prime number. \n",
      },
  ]

  actual = convert_session_to_eval_format(session)
  assert actual == expected
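For orientation, here is a minimal sketch (not part of this commit) of the turn format the tests above expect convert_session_to_eval_format to emit; the make_expected_turn helper is hypothetical and only restates the schema asserted in the tests.

# Illustrative only: make_expected_turn is a hypothetical helper, not ADK API.
def make_expected_turn(query, tool_use=None, intermediate=None, reference=""):
  """Builds one eval-format turn dict shaped as asserted in the tests above."""
  return {
      "query": query,
      "expected_tool_use": tool_use or [],
      "expected_intermediate_agent_responses": intermediate or [],
      "reference": reference,
  }

# Restates the expectation from test_convert_single_turn_text_only.
assert make_expected_turn("What is the time?", reference="It is 3 PM.") == {
    "query": "What is the time?",
    "expected_tool_use": [],
    "expected_intermediate_agent_responses": [],
    "reference": "It is 3 PM.",
}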
13  tests/unittests/evaluation/__init__.py  Normal file
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
259  tests/unittests/evaluation/test_response_evaluator.py  Normal file
@@ -0,0 +1,259 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Response Evaluator."""
from unittest.mock import MagicMock
from unittest.mock import patch

from google.adk.evaluation.response_evaluator import ResponseEvaluator
import pandas as pd
import pytest
from vertexai.preview.evaluation import MetricPromptTemplateExamples

# Mock object for the result normally returned by _perform_eval
MOCK_EVAL_RESULT = MagicMock()
MOCK_EVAL_RESULT.summary_metrics = {"mock_metric": 0.75, "another_mock": 3.5}
# Add a metrics_table for testing _print_results interaction
MOCK_EVAL_RESULT.metrics_table = pd.DataFrame({
    "prompt": ["mock_query1"],
    "response": ["mock_resp1"],
    "mock_metric": [0.75],
})

SAMPLE_TURN_1_ALL_KEYS = {
    "query": "query1",
    "response": "response1",
    "actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
    "expected_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
    "reference": "reference1",
}
SAMPLE_TURN_2_MISSING_REF = {
    "query": "query2",
    "response": "response2",
    "actual_tool_use": [],
    "expected_tool_use": [],
    # "reference": "reference2" # Missing
}
SAMPLE_TURN_3_MISSING_EXP_TOOLS = {
    "query": "query3",
    "response": "response3",
    "actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
    # "expected_tool_use": [], # Missing
    "reference": "reference3",
}
SAMPLE_TURN_4_MINIMAL = {
    "query": "query4",
    "response": "response4",
    # Minimal keys, others missing
}


@patch(
    "google.adk.evaluation.response_evaluator.ResponseEvaluator._perform_eval"
)
class TestResponseEvaluator:
  """A class to help organize "patch" decorators that are applicable to all tests."""

  def test_evaluate_none_dataset_raises_value_error(self, mock_perform_eval):
    """Test evaluate function raises ValueError for a None dataset."""
    with pytest.raises(ValueError, match="The evaluation dataset is empty."):
      ResponseEvaluator.evaluate(None, ["response_evaluation_score"])
    mock_perform_eval.assert_not_called()  # Ensure _perform_eval was not called

  def test_evaluate_empty_dataset_raises_value_error(self, mock_perform_eval):
    """Test evaluate function raises ValueError for an empty list."""
    with pytest.raises(ValueError, match="The evaluation dataset is empty."):
      ResponseEvaluator.evaluate([], ["response_evaluation_score"])
    mock_perform_eval.assert_not_called()  # Ensure _perform_eval was not called

  def test_evaluate_determines_metrics_correctly_for_perform_eval(
      self, mock_perform_eval
  ):
    """Test that the correct metrics list is passed to _perform_eval based on criteria/keys."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    # Test case 1: Only Coherence
    raw_data_1 = [[SAMPLE_TURN_1_ALL_KEYS]]
    criteria_1 = ["response_evaluation_score"]
    ResponseEvaluator.evaluate(raw_data_1, criteria_1)
    _, kwargs = mock_perform_eval.call_args
    assert kwargs["metrics"] == [
        MetricPromptTemplateExamples.Pointwise.COHERENCE
    ]
    mock_perform_eval.reset_mock()  # Reset mock for next call

    # Test case 2: Only Rouge
    raw_data_2 = [[SAMPLE_TURN_1_ALL_KEYS]]
    criteria_2 = ["response_match_score"]
    ResponseEvaluator.evaluate(raw_data_2, criteria_2)
    _, kwargs = mock_perform_eval.call_args
    assert kwargs["metrics"] == ["rouge_1"]
    mock_perform_eval.reset_mock()

    # Test case 3: No metrics if keys missing in first turn
    raw_data_3 = [[SAMPLE_TURN_4_MINIMAL, SAMPLE_TURN_1_ALL_KEYS]]
    criteria_3 = ["response_evaluation_score", "response_match_score"]
    ResponseEvaluator.evaluate(raw_data_3, criteria_3)
    _, kwargs = mock_perform_eval.call_args
    assert kwargs["metrics"] == []
    mock_perform_eval.reset_mock()

    # Test case 4: No metrics if criteria empty
    raw_data_4 = [[SAMPLE_TURN_1_ALL_KEYS]]
    criteria_4 = []
    ResponseEvaluator.evaluate(raw_data_4, criteria_4)
    _, kwargs = mock_perform_eval.call_args
    assert kwargs["metrics"] == []
    mock_perform_eval.reset_mock()

  def test_evaluate_calls_perform_eval_correctly_all_metrics(
      self, mock_perform_eval
  ):
    """Test evaluate function calls _perform_eval with expected args when all criteria/keys are present."""
    # Arrange
    mock_perform_eval.return_value = (
        MOCK_EVAL_RESULT  # Configure the mock return value
    )

    raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
    criteria = ["response_evaluation_score", "response_match_score"]

    # Act
    summary = ResponseEvaluator.evaluate(raw_data, criteria)

    # Assert
    # 1. Check metrics determined by _get_metrics (passed to _perform_eval)
    expected_metrics_list = [
        MetricPromptTemplateExamples.Pointwise.COHERENCE,
        "rouge_1",
    ]
    # 2. Check DataFrame prepared (passed to _perform_eval)
    expected_df_data = [{
        "prompt": "query1",
        "response": "response1",
        "actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
        "reference_trajectory": [{"tool_name": "tool_a", "tool_input": {}}],
        "reference": "reference1",
    }]
    expected_df = pd.DataFrame(expected_df_data)

    # Assert _perform_eval was called once
    mock_perform_eval.assert_called_once()
    # Get the arguments passed to the mocked _perform_eval
    _, kwargs = mock_perform_eval.call_args
    # Check the 'dataset' keyword argument
    pd.testing.assert_frame_equal(kwargs["dataset"], expected_df)
    # Check the 'metrics' keyword argument
    assert kwargs["metrics"] == expected_metrics_list

    # 3. Check the correct summary metrics are returned
    # (from mock_perform_eval's return value)
    assert summary == MOCK_EVAL_RESULT.summary_metrics

  def test_evaluate_prepares_dataframe_correctly_for_perform_eval(
      self, mock_perform_eval
  ):
    """Test that the DataFrame is correctly flattened and renamed before passing to _perform_eval."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    raw_data = [
        [SAMPLE_TURN_1_ALL_KEYS],  # Conversation 1
        [
            SAMPLE_TURN_2_MISSING_REF,
            SAMPLE_TURN_3_MISSING_EXP_TOOLS,
        ],  # Conversation 2
    ]
    criteria = [
        "response_match_score"
    ]  # Doesn't affect the DataFrame structure

    ResponseEvaluator.evaluate(raw_data, criteria)

    # Expected flattened and renamed data
    expected_df_data = [
        # Turn 1 (from SAMPLE_TURN_1_ALL_KEYS)
        {
            "prompt": "query1",
            "response": "response1",
            "actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
            "reference_trajectory": [{"tool_name": "tool_a", "tool_input": {}}],
            "reference": "reference1",
        },
        # Turn 2 (from SAMPLE_TURN_2_MISSING_REF)
        {
            "prompt": "query2",
            "response": "response2",
            "actual_tool_use": [],
            "reference_trajectory": [],
            # "reference": None # Missing key results in NaN in DataFrame
            # usually
        },
        # Turn 3 (from SAMPLE_TURN_3_MISSING_EXP_TOOLS)
        {
            "prompt": "query3",
            "response": "response3",
            "actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
            # "reference_trajectory": None, # Missing key results in NaN
            "reference": "reference3",
        },
    ]
    # Need to be careful with missing keys -> NaN when creating DataFrame
    # Pandas handles this automatically when creating from list of dicts
    expected_df = pd.DataFrame(expected_df_data)

    mock_perform_eval.assert_called_once()
    _, kwargs = mock_perform_eval.call_args
    # Compare the DataFrame passed to the mock
    pd.testing.assert_frame_equal(kwargs["dataset"], expected_df)

  @patch(
      "google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
  )  # Mock the private print method
  def test_evaluate_print_detailed_results(
      self, mock_print_results, mock_perform_eval
  ):
    """Test _print_results function is called when print_detailed_results=True."""
    mock_perform_eval.return_value = (
        MOCK_EVAL_RESULT  # Ensure _perform_eval returns our mock result
    )

    raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
    criteria = ["response_match_score"]

    ResponseEvaluator.evaluate(raw_data, criteria, print_detailed_results=True)

    # Assert _perform_eval was called
    mock_perform_eval.assert_called_once()
    # Assert _print_results was called once with the result object
    # from _perform_eval
    mock_print_results.assert_called_once_with(MOCK_EVAL_RESULT)

  @patch(
      "google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
  )
  def test_evaluate_no_print_detailed_results(
      self, mock_print_results, mock_perform_eval
  ):
    """Test _print_results function is NOT called when print_detailed_results=False (default)."""
    mock_perform_eval.return_value = MOCK_EVAL_RESULT

    raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
    criteria = ["response_match_score"]

    ResponseEvaluator.evaluate(raw_data, criteria, print_detailed_results=False)

    # Assert _perform_eval was called
    mock_perform_eval.assert_called_once()
    # Assert _print_results was NOT called
    mock_print_results.assert_not_called()
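As a quick reference (a sketch derived from the tests above, not part of the commit), the call pattern these tests exercise is a list of conversations, each a list of turn dicts, plus criterion names that map to the metrics checked above.

# Illustrative only; the names come from the fixtures and assertions in this file.
eval_dataset = [[SAMPLE_TURN_1_ALL_KEYS]]  # one conversation with one turn
criteria = [
    "response_evaluation_score",  # asserted above to map to Pointwise.COHERENCE
    "response_match_score",  # asserted above to map to "rouge_1"
]
summary = ResponseEvaluator.evaluate(eval_dataset, criteria)
# Per the tests, `summary` is the summary_metrics dict of the underlying eval result.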
271  tests/unittests/evaluation/test_trajectory_evaluator.py  Normal file
@@ -0,0 +1,271 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Trajectory Evaluator."""

import math
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
import pytest

# Define reusable tool call structures
TOOL_ROLL_DICE_16 = {"tool_name": "roll_die", "tool_input": {"sides": 16}}
TOOL_ROLL_DICE_6 = {"tool_name": "roll_die", "tool_input": {"sides": 6}}
TOOL_GET_WEATHER = {
    "tool_name": "get_weather",
    "tool_input": {"location": "Paris"},
}
TOOL_GET_WEATHER_SF = {
    "tool_name": "get_weather",
    "tool_input": {"location": "SF"},
}

# Sample data for turns
TURN_MATCH = {
    "query": "Q1",
    "response": "R1",
    "actual_tool_use": [TOOL_ROLL_DICE_16],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_INPUT = {
    "query": "Q2",
    "response": "R2",
    "actual_tool_use": [TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_NAME = {
    "query": "Q3",
    "response": "R3",
    "actual_tool_use": [TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MATCH_MULTIPLE = {
    "query": "Q4",
    "response": "R4",
    "actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_ORDER = {
    "query": "Q5",
    "response": "R5",
    "actual_tool_use": [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_LENGTH_ACTUAL_LONGER = {
    "query": "Q6",
    "response": "R6",
    "actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_GET_WEATHER],
}
TURN_MISMATCH_LENGTH_EXPECTED_LONGER = {
    "query": "Q7",
    "response": "R7",
    "actual_tool_use": [TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MATCH_WITH_MOCK_OUTPUT = {
    "query": "Q8",
    "response": "R8",
    "actual_tool_use": [TOOL_GET_WEATHER_SF],
    "expected_tool_use": [
        {**TOOL_GET_WEATHER_SF, "mock_tool_output": "Sunny"}
    ],  # Add mock output to expected
}
TURN_MATCH_EMPTY_TOOLS = {
    "query": "Q9",
    "response": "R9",
    "actual_tool_use": [],
    "expected_tool_use": [],
}
TURN_MISMATCH_EMPTY_VS_NONEMPTY = {
    "query": "Q10",
    "response": "R10",
    "actual_tool_use": [],
    "expected_tool_use": [TOOL_GET_WEATHER],
}


def test_evaluate_none_dataset_raises_value_error():
  """Tests evaluate function raises ValueError for a None dataset."""
  with pytest.raises(ValueError, match="The evaluation dataset is empty."):
    TrajectoryEvaluator.evaluate(None)


def test_evaluate_empty_dataset_raises_value_error():
  """Tests evaluate function raises ValueError for an empty list."""
  with pytest.raises(ValueError, match="The evaluation dataset is empty."):
    TrajectoryEvaluator.evaluate([])


def test_evaluate_single_turn_match():
  """Tests evaluate function with one conversation, one turn, perfect match."""
  eval_dataset = [[TURN_MATCH]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_single_turn_mismatch():
  """Tests evaluate function with one conversation, one turn, mismatch."""
  eval_dataset = [[TURN_MISMATCH_INPUT]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0


def test_evaluate_multiple_turns_all_match():
  """Tests evaluate function with one conversation, multiple turns, all match."""
  eval_dataset = [[TURN_MATCH, TURN_MATCH_MULTIPLE, TURN_MATCH_EMPTY_TOOLS]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_multiple_turns_mixed():
  """Tests evaluate function with one conversation, mixed match/mismatch turns."""
  eval_dataset = [
      [TURN_MATCH, TURN_MISMATCH_NAME, TURN_MATCH_MULTIPLE, TURN_MISMATCH_ORDER]
  ]
  # Expected: (1.0 + 0.0 + 1.0 + 0.0) / 4 = 0.5
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5


def test_evaluate_multiple_conversations_mixed():
  """Tests evaluate function with multiple conversations, mixed turns."""
  eval_dataset = [
      [TURN_MATCH, TURN_MISMATCH_INPUT],  # Conv 1: 1.0, 0.0 -> Avg 0.5
      [TURN_MATCH_MULTIPLE],  # Conv 2: 1.0 -> Avg 1.0
      [
          TURN_MISMATCH_ORDER,
          TURN_MISMATCH_LENGTH_ACTUAL_LONGER,
          TURN_MATCH,
      ],  # Conv 3: 0.0, 0.0, 1.0 -> Avg 1/3
  ]
  # Expected: (1.0 + 0.0 + 1.0 + 0.0 + 0.0 + 1.0) / 6 = 3.0 / 6 = 0.5
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5


def test_evaluate_ignores_mock_tool_output_in_expected():
  """Tests evaluate function correctly compares even if expected has mock_tool_output."""
  eval_dataset = [[TURN_MATCH_WITH_MOCK_OUTPUT]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_match_empty_tool_lists():
  """Tests evaluate function correctly matches empty tool lists."""
  eval_dataset = [[TURN_MATCH_EMPTY_TOOLS]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_mismatch_empty_vs_nonempty():
  """Tests evaluate function correctly mismatches empty vs non-empty tool lists."""
  eval_dataset = [[TURN_MISMATCH_EMPTY_VS_NONEMPTY]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0
  eval_dataset_rev = [[{
      **TURN_MISMATCH_EMPTY_VS_NONEMPTY,  # Swap actual/expected
      "actual_tool_use": [TOOL_GET_WEATHER],
      "expected_tool_use": [],
  }]]
  assert TrajectoryEvaluator.evaluate(eval_dataset_rev) == 0.0


def test_evaluate_dataset_with_empty_conversation():
  """Tests evaluate function handles dataset containing an empty conversation list."""
  eval_dataset = [[TURN_MATCH], []]  # One valid conversation, one empty
  # Should only evaluate the first conversation -> 1.0 / 1 turn = 1.0
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_dataset_only_empty_conversation():
  """Tests evaluate function handles dataset with only an empty conversation."""
  eval_dataset = [[]]
  # No rows evaluated, mean of empty series is NaN
  # Depending on desired behavior, this could be 0.0 or NaN. The code returns
  # NaN.
  assert math.isnan(TrajectoryEvaluator.evaluate(eval_dataset))


def test_evaluate_print_detailed_results(capsys):
  """Tests evaluate function runs with print_detailed_results=True and prints something."""
  eval_dataset = [[TURN_MATCH, TURN_MISMATCH_INPUT]]
  TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
  captured = capsys.readouterr()
  assert "query" in captured.out  # Check if the results table header is printed
  assert "R1" in captured.out  # Check if some data is printed
  assert "Failures:" in captured.out  # Check if failures header is printed
  assert "Q2" in captured.out  # Check if the failing query is printed


def test_evaluate_no_failures_print(capsys):
  """Tests evaluate function does not print Failures section when all turns match."""
  eval_dataset = [[TURN_MATCH]]
  TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
  captured = capsys.readouterr()
  assert "query" in captured.out  # Results table should still print
  assert "Failures:" not in captured.out  # Failures section should NOT print


def test_are_tools_equal_identical():
  """Tests are_tools_equal function with identical lists."""
  list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_empty():
  """Tests are_tools_equal function with empty lists."""
  assert TrajectoryEvaluator.are_tools_equal([], [])


def test_are_tools_equal_different_order():
  """Tests are_tools_equal function with same tools, different order."""
  list_a = [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER]
  list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_different_length():
  """Tests are_tools_equal function with lists of different lengths."""
  list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  list_b = [TOOL_GET_WEATHER]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_different_input_values():
  """Tests are_tools_equal function with different input values."""
  list_a = [TOOL_ROLL_DICE_16]
  list_b = [TOOL_ROLL_DICE_6]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_different_tool_names():
  """Tests are_tools_equal function with different tool names."""
  list_a = [TOOL_ROLL_DICE_16]
  list_b = [TOOL_GET_WEATHER]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_ignores_extra_keys():
  """Tests are_tools_equal function ignores keys other than tool_name/tool_input."""
  list_a = [{
      "tool_name": "get_weather",
      "tool_input": {"location": "Paris"},
      "extra_key": "abc",
  }]
  list_b = [{
      "tool_name": "get_weather",
      "tool_input": {"location": "Paris"},
      "other_key": 123,
  }]
  assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_one_empty_one_not():
  """Tests are_tools_equal function with one empty list and one non-empty list."""
  list_a = []
  list_b = [TOOL_GET_WEATHER]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
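For reference, a short sketch (not part of the commit) of the scoring behavior these tests pin down: each turn scores 1.0 only on an exact, order-sensitive match of tool_name/tool_input (extra keys such as mock_tool_output are ignored), and the overall score is the mean across all evaluated turns.

# Illustrative only; the fixtures are the ones defined in this test file.
dataset = [[TURN_MATCH, TURN_MISMATCH_INPUT]]  # one conversation, two turns
score = TrajectoryEvaluator.evaluate(dataset)  # (1.0 + 0.0) / 2
assert score == 0.5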
@@ -225,3 +225,76 @@ def test_create_new_session_will_merge_states(service_type):
  assert session_2.state.get('user:key1') == 'value1'
  assert not session_2.state.get('key1')
  assert not session_2.state.get('temp:key')


@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_bytes(service_type):
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'user'

  session = session_service.create_session(app_name=app_name, user_id=user_id)
  event = Event(
      invocation_id='invocation',
      author='user',
      content=types.Content(
          role='user',
          parts=[
              types.Part.from_bytes(
                  data=b'test_image_data', mime_type='image/png'
              ),
          ],
      ),
  )
  session_service.append_event(session=session, event=event)

  assert session.events[0].content.parts[0] == types.Part.from_bytes(
      data=b'test_image_data', mime_type='image/png'
  )

  events = session_service.get_session(
      app_name=app_name, user_id=user_id, session_id=session.id
  ).events
  assert len(events) == 1
  assert events[0].content.parts[0] == types.Part.from_bytes(
      data=b'test_image_data', mime_type='image/png'
  )


@pytest.mark.parametrize(
    'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
)
def test_append_event_complete(service_type):
  session_service = get_session_service(service_type)
  app_name = 'my_app'
  user_id = 'user'

  session = session_service.create_session(app_name=app_name, user_id=user_id)
  event = Event(
      invocation_id='invocation',
      author='user',
      content=types.Content(role='user', parts=[types.Part(text='test_text')]),
      turn_complete=True,
      partial=False,
      actions=EventActions(
          artifact_delta={
              'file': 0,
          },
          transfer_to_agent='agent',
          escalate=True,
      ),
      long_running_tool_ids={'tool1'},
      error_code='error_code',
      error_message='error_message',
      interrupted=True,
  )
  session_service.append_event(session=session, event=event)

  assert (
      session_service.get_session(
          app_name=app_name, user_id=user_id, session_id=session.id
      )
      == session
  )

@@ -57,7 +57,7 @@ MOCK_EVENT_JSON = [
    {
        'name': (
            'projects/test-project/locations/test-location/'
            'reasoningEngines/test_engine/sessions/1/events/123'
            'reasoningEngines/123/sessions/1/events/123'
        ),
        'invocationId': '123',
        'author': 'user',
@@ -111,7 +111,7 @@ MOCK_SESSION = Session(


SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
LRO_REGEX = r'^operations/([^/]+)$'

@@ -136,39 +136,52 @@ class MockApiClient:
          else:
            raise ValueError(f'Session not found: {session_id}')
      elif re.match(SESSIONS_REGEX, path):
        match = re.match(SESSIONS_REGEX, path)
        return {
            'sessions': self.session_dict.values(),
            'sessions': [
                session
                for session in self.session_dict.values()
                if session['userId'] == match.group(2)
            ],
        }
      elif re.match(EVENTS_REGEX, path):
        match = re.match(EVENTS_REGEX, path)
        if match:
          return {'sessionEvents': self.event_dict[match.group(2)]}
        return {
            'sessionEvents': (
                self.event_dict[match.group(2)]
                if match.group(2) in self.event_dict
                else []
            )
        }
      elif re.match(LRO_REGEX, path):
        return {
            'name': (
                'projects/test-project/locations/test-location/'
                'reasoningEngines/123/sessions/123'
                'reasoningEngines/123/sessions/4'
            ),
            'done': True,
        }
      else:
        raise ValueError(f'Unsupported path: {path}')
    elif http_method == 'POST':
      id = str(uuid.uuid4())
      self.session_dict[id] = {
      new_session_id = '4'
      self.session_dict[new_session_id] = {
          'name': (
              'projects/test-project/locations/test-location/'
              'reasoningEngines/123/sessions/'
              + id
              + new_session_id
          ),
          'userId': request_dict['user_id'],
          'sessionState': request_dict.get('sessionState', {}),
          'sessionState': request_dict.get('session_state', {}),
          'updateTime': '2024-12-12T12:12:12.123456Z',
      }
      return {
          'name': (
              'projects/test_project/locations/test_location/'
              'reasoningEngines/test_engine/sessions/123'
              'reasoningEngines/123/sessions/'
              + new_session_id
              + '/operations/111'
          ),
          'done': False,
      }

@@ -223,24 +236,28 @@ def test_get_and_delete_session():
  )
  assert str(excinfo.value) == 'Session not found: 1'

def test_list_sessions():
  session_service = mock_vertex_ai_session_service()
  sessions = session_service.list_sessions(app_name='123', user_id='user')
  assert len(sessions.sessions) == 2
  assert sessions.sessions[0].id == '1'
  assert sessions.sessions[1].id == '2'

def test_create_session():
  session_service = mock_vertex_ai_session_service()
  session = session_service.create_session(
      app_name='123', user_id='user', state={'key': 'value'}
  )
  assert session.state == {'key': 'value'}
  assert session.app_name == '123'
  assert session.user_id == 'user'
  assert session.last_update_time is not None
def test_list_sessions():
  session_service = mock_vertex_ai_session_service()
  sessions = session_service.list_sessions(app_name='123', user_id='user')
  assert len(sessions.sessions) == 2
  assert sessions.sessions[0].id == '1'
  assert sessions.sessions[1].id == '2'

  session_id = session.id
  assert session == session_service.get_session(
      app_name='123', user_id='user', session_id=session_id
  )

def test_create_session():
  session_service = mock_vertex_ai_session_service()

  state = {'key': 'value'}
  session = session_service.create_session(
      app_name='123', user_id='user', state=state
  )
  assert session.state == state
  assert session.app_name == '123'
  assert session.user_id == 'user'
  assert session.last_update_time is not None

  session_id = session.id
  assert session == session_service.get_session(
      app_name='123', user_id='user', session_id=session_id
  )

@@ -119,7 +119,7 @@ def calendar_api_spec():
        "methods": {
            "get": {
                "id": "calendar.calendars.get",
                "path": "calendars/{calendarId}",
                "flatPath": "calendars/{calendarId}",
                "httpMethod": "GET",
                "description": "Returns metadata for a calendar.",
                "parameters": {
@@ -151,7 +151,7 @@ def calendar_api_spec():
        "methods": {
            "list": {
                "id": "calendar.events.list",
                "path": "calendars/{calendarId}/events",
                "flatPath": "calendars/{calendarId}/events",
                "httpMethod": "GET",
                "description": (
                    "Returns events on the specified calendar."