mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-18 11:22:22 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
committed by
hangfei
parent
290058eb05
commit
61d4be2d76
14
tests/__init__.py
Normal file
14
tests/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
10
tests/integration/.env.example
Normal file
10
tests/integration/.env.example
Normal file
@@ -0,0 +1,10 @@
|
||||
# Copy as .env file and fill your values below to run integration tests.
|
||||
|
||||
# Choose Backend: GOOGLE_AI_ONLY | VERTEX_ONLY | BOTH (default)
|
||||
TEST_BACKEND=BOTH
|
||||
|
||||
# ML Dev backend config
|
||||
GOOGLE_API_KEY=YOUR_VALUE_HERE
|
||||
# Vertex backend config
|
||||
GOOGLE_CLOUD_PROJECT=YOUR_VALUE_HERE
|
||||
GOOGLE_CLOUD_LOCATION=YOUR_VALUE_HERE
|
||||
18
tests/integration/__init__.py
Normal file
18
tests/integration/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
# This allows pytest to show the values of the asserts.
|
||||
pytest.register_assert_rewrite('tests.integration.utils')
|
||||
119
tests/integration/conftest.py
Normal file
119
tests/integration/conftest.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Literal
|
||||
import warnings
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from google.adk import Agent
|
||||
from pytest import fixture
|
||||
from pytest import FixtureRequest
|
||||
from pytest import hookimpl
|
||||
from pytest import Metafunc
|
||||
|
||||
from .utils import TestRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_env_for_tests():
|
||||
dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
|
||||
if not os.path.exists(dotenv_path):
|
||||
warnings.warn(
|
||||
f'Missing .env file at {dotenv_path}. See dotenv.sample for an example.'
|
||||
)
|
||||
else:
|
||||
load_dotenv(dotenv_path, override=True, verbose=True)
|
||||
if 'GOOGLE_API_KEY' not in os.environ:
|
||||
warnings.warn(
|
||||
'Missing GOOGLE_API_KEY in the environment variables. GOOGLE_AI backend'
|
||||
' integration tests will fail.'
|
||||
)
|
||||
for env_var in [
|
||||
'GOOGLE_CLOUD_PROJECT',
|
||||
'GOOGLE_CLOUD_LOCATION',
|
||||
]:
|
||||
if env_var not in os.environ:
|
||||
warnings.warn(
|
||||
f'Missing {env_var} in the environment variables. Vertex backend'
|
||||
' integration tests will fail.'
|
||||
)
|
||||
|
||||
|
||||
load_env_for_tests()
|
||||
|
||||
BackendType = Literal['GOOGLE_AI', 'VERTEX']
|
||||
|
||||
|
||||
@fixture
|
||||
def agent_runner(request: FixtureRequest) -> TestRunner:
|
||||
assert isinstance(request.param, dict)
|
||||
|
||||
if 'agent' in request.param:
|
||||
assert isinstance(request.param['agent'], Agent)
|
||||
return TestRunner(request.param['agent'])
|
||||
elif 'agent_name' in request.param:
|
||||
assert isinstance(request.param['agent_name'], str)
|
||||
return TestRunner.from_agent_name(request.param['agent_name'])
|
||||
|
||||
raise NotImplementedError('Must provide agent or agent_name.')
|
||||
|
||||
|
||||
@fixture(autouse=True)
|
||||
def llm_backend(request: FixtureRequest):
|
||||
# Set backend environment value.
|
||||
original_val = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI')
|
||||
backend_type = request.param
|
||||
if backend_type == 'GOOGLE_AI':
|
||||
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = '0'
|
||||
else:
|
||||
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = '1'
|
||||
|
||||
yield # Run the test
|
||||
|
||||
# Restore the environment
|
||||
if original_val is None:
|
||||
os.environ.pop('GOOGLE_GENAI_USE_VERTEXAI', None)
|
||||
else:
|
||||
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = original_val
|
||||
|
||||
|
||||
@hookimpl(tryfirst=True)
|
||||
def pytest_generate_tests(metafunc: Metafunc):
|
||||
if llm_backend.__name__ in metafunc.fixturenames:
|
||||
if not _is_explicitly_marked(llm_backend.__name__, metafunc):
|
||||
test_backend = os.environ.get('TEST_BACKEND', 'BOTH')
|
||||
if test_backend == 'GOOGLE_AI_ONLY':
|
||||
metafunc.parametrize(llm_backend.__name__, ['GOOGLE_AI'], indirect=True)
|
||||
elif test_backend == 'VERTEX_ONLY':
|
||||
metafunc.parametrize(llm_backend.__name__, ['VERTEX'], indirect=True)
|
||||
elif test_backend == 'BOTH':
|
||||
metafunc.parametrize(
|
||||
llm_backend.__name__, ['GOOGLE_AI', 'VERTEX'], indirect=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Invalid TEST_BACKEND value: {test_backend}, should be one of'
|
||||
' [GOOGLE_AI_ONLY, VERTEX_ONLY, BOTH]'
|
||||
)
|
||||
|
||||
|
||||
def _is_explicitly_marked(mark_name: str, metafunc: Metafunc) -> bool:
|
||||
if hasattr(metafunc.function, 'pytestmark'):
|
||||
for mark in metafunc.function.pytestmark:
|
||||
if mark.name == 'parametrize' and mark.args[0] == mark_name:
|
||||
return True
|
||||
return False
|
||||
14
tests/integration/fixture/__init__.py
Normal file
14
tests/integration/fixture/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
15
tests/integration/fixture/agent_with_config/__init__.py
Normal file
15
tests/integration/fixture/agent_with_config/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
88
tests/integration/fixture/agent_with_config/agent.py
Normal file
88
tests/integration/fixture/agent_with_config/agent.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk import Agent
|
||||
from google.genai import types
|
||||
|
||||
new_message = types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="Count a number")],
|
||||
)
|
||||
|
||||
google_agent_1 = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent_1",
|
||||
description="The first agent in the team.",
|
||||
instruction="Just say 1",
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
),
|
||||
)
|
||||
|
||||
google_agent_2 = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent_2",
|
||||
description="The second agent in the team.",
|
||||
instruction="Just say 2",
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.2,
|
||||
safety_settings=[{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_ONLY_HIGH",
|
||||
}],
|
||||
),
|
||||
)
|
||||
|
||||
google_agent_3 = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent_3",
|
||||
description="The third agent in the team.",
|
||||
instruction="Just say 3",
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.5,
|
||||
safety_settings=[{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE",
|
||||
}],
|
||||
),
|
||||
)
|
||||
|
||||
google_agent_with_instruction_in_config = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.5, system_instruction="Count 1"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def function():
|
||||
pass
|
||||
|
||||
|
||||
google_agent_with_tools_in_config = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.5, tools=[function]
|
||||
),
|
||||
)
|
||||
|
||||
google_agent_with_response_schema_in_config = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="agent",
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.5, response_schema={"key": "value"}
|
||||
),
|
||||
)
|
||||
15
tests/integration/fixture/callback_agent/__init__.py
Normal file
15
tests/integration/fixture/callback_agent/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
105
tests/integration/fixture/callback_agent/agent.py
Normal file
105
tests/integration/fixture/callback_agent/agent.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.models import LlmResponse
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def before_agent_call_end_invocation(
|
||||
callback_context: CallbackContext,
|
||||
) -> types.Content:
|
||||
return types.Content(
|
||||
role='model',
|
||||
parts=[types.Part(text='End invocation event before agent call.')],
|
||||
)
|
||||
|
||||
|
||||
def before_agent_call(
|
||||
invocation_context: InvocationContext,
|
||||
) -> types.Content:
|
||||
return types.Content(
|
||||
role='model',
|
||||
parts=[types.Part.from_text(text='Plain text event before agent call.')],
|
||||
)
|
||||
|
||||
|
||||
def before_model_call_end_invocation(
|
||||
callback_context: CallbackContext, llm_request: LlmRequest
|
||||
) -> LlmResponse:
|
||||
return LlmResponse(
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text='End invocation event before model call.'
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def before_model_call(
|
||||
invocation_context: InvocationContext, request: LlmRequest
|
||||
) -> LlmResponse:
|
||||
request.config.system_instruction = 'Just return 999 as response.'
|
||||
return LlmResponse(
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text='Update request event before model call.'
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def after_model_call(
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse,
|
||||
) -> Optional[LlmResponse]:
|
||||
content = llm_response.content
|
||||
if not content or not content.parts or not content.parts[0].text:
|
||||
return
|
||||
|
||||
content.parts[0].text += 'Update response event after model call.'
|
||||
return llm_response
|
||||
|
||||
|
||||
before_agent_callback_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='before_agent_callback_agent',
|
||||
instruction='echo 1',
|
||||
before_agent_callback=before_agent_call_end_invocation,
|
||||
)
|
||||
|
||||
before_model_callback_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='before_model_callback_agent',
|
||||
instruction='echo 2',
|
||||
before_model_callback=before_model_call_end_invocation,
|
||||
)
|
||||
|
||||
after_model_callback_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='after_model_callback_agent',
|
||||
instruction='Say hello',
|
||||
after_model_callback=after_model_call,
|
||||
)
|
||||
1
tests/integration/fixture/context_update_test/OWNERS
Normal file
1
tests/integration/fixture/context_update_test/OWNERS
Normal file
@@ -0,0 +1 @@
|
||||
gkcng
|
||||
15
tests/integration/fixture/context_update_test/__init__.py
Normal file
15
tests/integration/fixture/context_update_test/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
43
tests/integration/fixture/context_update_test/agent.py
Normal file
43
tests/integration/fixture/context_update_test/agent.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.tools import ToolContext
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def update_fc(
|
||||
data_one: str,
|
||||
data_two: Union[int, float, str],
|
||||
data_three: list[str],
|
||||
data_four: List[Union[int, float, str]],
|
||||
tool_context: ToolContext,
|
||||
):
|
||||
"""Simply ask to update these variables in the context"""
|
||||
tool_context.actions.update_state("data_one", data_one)
|
||||
tool_context.actions.update_state("data_two", data_two)
|
||||
tool_context.actions.update_state("data_three", data_three)
|
||||
tool_context.actions.update_state("data_four", data_four)
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="root_agent",
|
||||
instruction="Call tools",
|
||||
flow="auto",
|
||||
tools=[update_fc],
|
||||
)
|
||||
@@ -0,0 +1,582 @@
|
||||
{
|
||||
"id": "ead43200-b575-4241-9248-233b4be4f29a",
|
||||
"context": {
|
||||
"_time": "2024-12-01 09:02:43.531503",
|
||||
"data_one": "RRRR",
|
||||
"data_two": "3.141529",
|
||||
"data_three": [
|
||||
"apple",
|
||||
"banana"
|
||||
],
|
||||
"data_four": [
|
||||
"1",
|
||||
"hello",
|
||||
"3.14"
|
||||
]
|
||||
},
|
||||
"events": [
|
||||
{
|
||||
"invocation_id": "6BGrtKJu",
|
||||
"author": "user",
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "hi"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
},
|
||||
"options": {},
|
||||
"id": "ltzQTqR4",
|
||||
"timestamp": 1733043686.8428597
|
||||
},
|
||||
{
|
||||
"invocation_id": "6BGrtKJu",
|
||||
"author": "root_agent",
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "Hello! 👋 How can I help you today? \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"options": {
|
||||
"partial": false
|
||||
},
|
||||
"id": "ClSROx8b",
|
||||
"timestamp": 1733043688.1030986
|
||||
},
|
||||
{
|
||||
"invocation_id": "M3dUcVa8",
|
||||
"author": "user",
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
},
|
||||
"options": {},
|
||||
"id": "yxigGwIZ",
|
||||
"timestamp": 1733043745.9900541
|
||||
},
|
||||
{
|
||||
"invocation_id": "M3dUcVa8",
|
||||
"author": "root_agent",
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"function_call": {
|
||||
"args": {
|
||||
"data_four": [
|
||||
"1",
|
||||
"hello",
|
||||
"3.14"
|
||||
],
|
||||
"data_two": "3.141529",
|
||||
"data_three": [
|
||||
"apple",
|
||||
"banana"
|
||||
],
|
||||
"data_one": "RRRR"
|
||||
},
|
||||
"name": "update_fc"
|
||||
}
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"options": {
|
||||
"partial": false
|
||||
},
|
||||
"id": "8V6de8th",
|
||||
"timestamp": 1733043747.4545543
|
||||
},
|
||||
{
|
||||
"invocation_id": "M3dUcVa8",
|
||||
"author": "root_agent",
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": "update_fc",
|
||||
"response": {}
|
||||
}
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
},
|
||||
"options": {
|
||||
"update_context": {
|
||||
"data_one": "RRRR",
|
||||
"data_two": "3.141529",
|
||||
"data_three": [
|
||||
"apple",
|
||||
"banana"
|
||||
],
|
||||
"data_four": [
|
||||
"1",
|
||||
"hello",
|
||||
"3.14"
|
||||
]
|
||||
},
|
||||
"function_call_event_id": "8V6de8th"
|
||||
},
|
||||
"id": "dkTj5v8B",
|
||||
"timestamp": 1733043747.457031
|
||||
},
|
||||
{
|
||||
"invocation_id": "M3dUcVa8",
|
||||
"author": "root_agent",
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "OK. I've updated the data. Anything else? \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"options": {
|
||||
"partial": false
|
||||
},
|
||||
"id": "OZ77XR41",
|
||||
"timestamp": 1733043748.7901294
|
||||
}
|
||||
],
|
||||
"past_events": [],
|
||||
"pending_events": {},
|
||||
"artifacts": {},
|
||||
"event_logs": [
|
||||
{
|
||||
"invocation_id": "6BGrtKJu",
|
||||
"event_id": "ClSROx8b",
|
||||
"model_request": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "hi"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"config": {
|
||||
"system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
|
||||
"tools": [
|
||||
{
|
||||
"function_declarations": [
|
||||
{
|
||||
"description": "Hello",
|
||||
"name": "update_fc",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"data_one": {
|
||||
"type": "STRING"
|
||||
},
|
||||
"data_two": {
|
||||
"type": "STRING"
|
||||
},
|
||||
"data_three": {
|
||||
"type": "ARRAY",
|
||||
"items": {
|
||||
"type": "STRING"
|
||||
}
|
||||
},
|
||||
"data_four": {
|
||||
"type": "ARRAY",
|
||||
"items": {
|
||||
"any_of": [
|
||||
{
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"type": "NUMBER"
|
||||
},
|
||||
{
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"type": "STRING"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"model_response": {
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "Hello! 👋 How can I help you today? \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"avg_logprobs": -0.15831730915949896,
|
||||
"finish_reason": "STOP",
|
||||
"safety_ratings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.071777344,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.07080078
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.16308594,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.14160156
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.09423828,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.037841797
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.059326172,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.02368164
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"model_version": "gemini-1.5-flash-001",
|
||||
"usage_metadata": {
|
||||
"candidates_token_count": 13,
|
||||
"prompt_token_count": 32,
|
||||
"total_token_count": 45
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"invocation_id": "M3dUcVa8",
|
||||
"event_id": "8V6de8th",
|
||||
"model_request": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "hi"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "Hello! 👋 How can I help you today? \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"config": {
|
||||
"system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
|
||||
"tools": [
|
||||
{
|
||||
"function_declarations": [
|
||||
{
|
||||
"description": "Hello",
|
||||
"name": "update_fc",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"data_one": {
|
||||
"type": "STRING"
|
||||
},
|
||||
"data_two": {
|
||||
"type": "STRING"
|
||||
},
|
||||
"data_three": {
|
||||
"type": "ARRAY",
|
||||
"items": {
|
||||
"type": "STRING"
|
||||
}
|
||||
},
|
||||
"data_four": {
|
||||
"type": "ARRAY",
|
||||
"items": {
|
||||
"any_of": [
|
||||
{
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"type": "NUMBER"
|
||||
},
|
||||
{
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"type": "STRING"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"model_response": {
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"function_call": {
|
||||
"args": {
|
||||
"data_four": [
|
||||
"1",
|
||||
"hello",
|
||||
"3.14"
|
||||
],
|
||||
"data_two": "3.141529",
|
||||
"data_three": [
|
||||
"apple",
|
||||
"banana"
|
||||
],
|
||||
"data_one": "RRRR"
|
||||
},
|
||||
"name": "update_fc"
|
||||
}
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"avg_logprobs": -2.100960955431219e-6,
|
||||
"finish_reason": "STOP",
|
||||
"safety_ratings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.12158203,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.13671875
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.421875,
|
||||
"severity": "HARM_SEVERITY_LOW",
|
||||
"severity_score": 0.24511719
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.15722656,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.072753906
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.083984375,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.03564453
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"model_version": "gemini-1.5-flash-001",
|
||||
"usage_metadata": {
|
||||
"candidates_token_count": 32,
|
||||
"prompt_token_count": 94,
|
||||
"total_token_count": 126
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"invocation_id": "M3dUcVa8",
|
||||
"event_id": "OZ77XR41",
|
||||
"model_request": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "hi"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "Hello! 👋 How can I help you today? \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"text": "update data_one to be RRRR, data_two to be 3.141529, data_three to be apple and banana, data_four to be 1, hello, and 3.14"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"function_call": {
|
||||
"args": {
|
||||
"data_four": [
|
||||
"1",
|
||||
"hello",
|
||||
"3.14"
|
||||
],
|
||||
"data_two": "3.141529",
|
||||
"data_three": [
|
||||
"apple",
|
||||
"banana"
|
||||
],
|
||||
"data_one": "RRRR"
|
||||
},
|
||||
"name": "update_fc"
|
||||
}
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
{
|
||||
"parts": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": "update_fc",
|
||||
"response": {}
|
||||
}
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"config": {
|
||||
"system_instruction": "You are an agent. Your name is root_agent.\nCall tools",
|
||||
"tools": [
|
||||
{
|
||||
"function_declarations": [
|
||||
{
|
||||
"description": "Hello",
|
||||
"name": "update_fc",
|
||||
"parameters": {
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
"data_one": {
|
||||
"type": "STRING"
|
||||
},
|
||||
"data_two": {
|
||||
"type": "STRING"
|
||||
},
|
||||
"data_three": {
|
||||
"type": "ARRAY",
|
||||
"items": {
|
||||
"type": "STRING"
|
||||
}
|
||||
},
|
||||
"data_four": {
|
||||
"type": "ARRAY",
|
||||
"items": {
|
||||
"any_of": [
|
||||
{
|
||||
"type": "INTEGER"
|
||||
},
|
||||
{
|
||||
"type": "NUMBER"
|
||||
},
|
||||
{
|
||||
"type": "STRING"
|
||||
}
|
||||
],
|
||||
"type": "STRING"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"model_response": {
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "OK. I've updated the data. Anything else? \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"avg_logprobs": -0.22089435373033797,
|
||||
"finish_reason": "STOP",
|
||||
"safety_ratings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.04663086,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.09423828
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.18554688,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.111328125
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.071777344,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.03112793
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probability_score": 0.043945313,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severity_score": 0.057373047
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"model_version": "gemini-1.5-flash-001",
|
||||
"usage_metadata": {
|
||||
"candidates_token_count": 14,
|
||||
"prompt_token_count": 129,
|
||||
"total_token_count": 143
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
15
tests/integration/fixture/context_variable_agent/__init__.py
Normal file
15
tests/integration/fixture/context_variable_agent/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
115
tests/integration/fixture/context_variable_agent/agent.py
Normal file
115
tests/integration/fixture/context_variable_agent/agent.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.planners import PlanReActPlanner
|
||||
from google.adk.tools import ToolContext
|
||||
|
||||
|
||||
def update_fc(
|
||||
data_one: str,
|
||||
data_two: Union[int, float, str],
|
||||
data_three: list[str],
|
||||
data_four: List[Union[int, float, str]],
|
||||
tool_context: ToolContext,
|
||||
) -> str:
|
||||
"""Simply ask to update these variables in the context"""
|
||||
tool_context.actions.update_state('data_one', data_one)
|
||||
tool_context.actions.update_state('data_two', data_two)
|
||||
tool_context.actions.update_state('data_three', data_three)
|
||||
tool_context.actions.update_state('data_four', data_four)
|
||||
return 'The function `update_fc` executed successfully'
|
||||
|
||||
|
||||
def echo_info(customer_id: str) -> str:
|
||||
"""Echo the context variable"""
|
||||
return customer_id
|
||||
|
||||
|
||||
def build_global_instruction(invocation_context: InvocationContext) -> str:
|
||||
return (
|
||||
'This is the gloabl agent instruction for invocation:'
|
||||
f' {invocation_context.invocation_id}.'
|
||||
)
|
||||
|
||||
|
||||
def build_sub_agent_instruction(invocation_context: InvocationContext) -> str:
|
||||
return 'This is the plain text sub agent instruction.'
|
||||
|
||||
|
||||
context_variable_echo_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='context_variable_echo_agent',
|
||||
instruction=(
|
||||
'Use the echo_info tool to echo {customerId}, {customerInt},'
|
||||
' {customerFloat}, and {customerJson}. Ask for it if you need to.'
|
||||
),
|
||||
flow='auto',
|
||||
tools=[echo_info],
|
||||
)
|
||||
|
||||
context_variable_with_complicated_format_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='context_variable_echo_agent',
|
||||
instruction=(
|
||||
'Use the echo_info tool to echo { customerId }, {{customer_int }, { '
|
||||
" non-identifier-float}}, {artifact.fileName}, {'key1': 'value1'} and"
|
||||
" {{'key2': 'value2'}}. Ask for it if you need to."
|
||||
),
|
||||
flow='auto',
|
||||
tools=[echo_info],
|
||||
)
|
||||
|
||||
context_variable_with_nl_planner_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='context_variable_with_nl_planner_agent',
|
||||
instruction=(
|
||||
'Use the echo_info tool to echo {customerId}. Ask for it if you'
|
||||
' need to.'
|
||||
),
|
||||
flow='auto',
|
||||
planner=PlanReActPlanner(),
|
||||
tools=[echo_info],
|
||||
)
|
||||
|
||||
context_variable_with_function_instruction_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='context_variable_with_function_instruction_agent',
|
||||
instruction=build_sub_agent_instruction,
|
||||
flow='auto',
|
||||
)
|
||||
|
||||
context_variable_update_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='context_variable_update_agent',
|
||||
instruction='Call tools',
|
||||
flow='auto',
|
||||
tools=[update_fc],
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-1.5-flash',
|
||||
name='root_agent',
|
||||
description='The root agent.',
|
||||
flow='auto',
|
||||
global_instruction=build_global_instruction,
|
||||
sub_agents=[
|
||||
context_variable_with_nl_planner_agent,
|
||||
context_variable_update_agent,
|
||||
],
|
||||
)
|
||||
15
tests/integration/fixture/customer_support_ma/__init__.py
Normal file
15
tests/integration/fixture/customer_support_ma/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
172
tests/integration/fixture/customer_support_ma/agent.py
Normal file
172
tests/integration/fixture/customer_support_ma/agent.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.agents import RemoteAgent
|
||||
from google.adk.examples import Example
|
||||
from google.adk.sessions import Session
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def reset_data():
|
||||
pass
|
||||
|
||||
|
||||
def fetch_user_flight_information(customer_email: str) -> str:
|
||||
"""Fetch user flight information."""
|
||||
return """
|
||||
[{"ticket_no": "7240005432906569", "book_ref": "C46E9F", "flight_id": 19250, "flight_no": "LX0112", "departure_airport": "CDG", "arrival_airport": "BSL", "scheduled_departure": "2024-12-30 12:09:03.561731-04:00", "scheduled_arrival": "2024-12-30 13:39:03.561731-04:00", "seat_no": "18E", "fare_conditions": "Economy"}]
|
||||
"""
|
||||
|
||||
|
||||
def list_customer_flights(customer_email: str) -> str:
|
||||
return "{'flights': [{'book_ref': 'C46E9F'}]}"
|
||||
|
||||
|
||||
def update_ticket_to_new_flight(ticket_no: str, new_flight_id: str) -> str:
|
||||
return 'OK, your ticket has been updated.'
|
||||
|
||||
|
||||
def lookup_company_policy(topic: str) -> str:
|
||||
"""Lookup policies for flight cancelation and rebooking."""
|
||||
return """
|
||||
1. How can I change my booking?
|
||||
* The ticket number must start with 724 (SWISS ticket no./plate).
|
||||
* The ticket was not paid for by barter or voucher (there are exceptions to voucher payments; if the ticket was paid for in full by voucher, then it may be possible to rebook online under certain circumstances. If it is not possible to rebook online because of the payment method, then you will be informed accordingly during the rebooking process).
|
||||
* There must be an active flight booking for your ticket. It is not possible to rebook open tickets or tickets without the corresponding flight segments online at the moment.
|
||||
* It is currently only possible to rebook outbound (one-way) tickets or return tickets with single flight routes (point-to-point).
|
||||
"""
|
||||
|
||||
|
||||
def search_flights(
|
||||
departure_airport: str = None,
|
||||
arrival_airport: str = None,
|
||||
start_time: str = None,
|
||||
end_time: str = None,
|
||||
) -> list[dict]:
|
||||
return """
|
||||
[{"flight_id": 19238, "flight_no": "LX0112", "scheduled_departure": "2024-05-08 12:09:03.561731-04:00", "scheduled_arrival": "2024-05-08 13:39:03.561731-04:00", "departure_airport": "CDG", "arrival_airport": "BSL", "status": "Scheduled", "aircraft_code": "SU9", "actual_departure": null, "actual_arrival": null}, {"flight_id": 19242, "flight_no": "LX0112", "scheduled_departure": "2024-05-09 12:09:03.561731-04:00", "scheduled_arrival": "2024-05-09 13:39:03.561731-04:00", "departure_airport": "CDG", "arrival_airport": "BSL", "status": "Scheduled", "aircraft_code": "SU9", "actual_departure": null, "actual_arrival": null}]"""
|
||||
|
||||
|
||||
def search_hotels(
|
||||
location: str = None,
|
||||
price_tier: str = None,
|
||||
checkin_date: str = None,
|
||||
checkout_date: str = None,
|
||||
) -> list[dict]:
|
||||
return """
|
||||
[{"id": 1, "name": "Hilton Basel", "location": "Basel", "price_tier": "Luxury"}, {"id": 3, "name": "Hyatt Regency Basel", "location": "Basel", "price_tier": "Upper Upscale"}, {"id": 8, "name": "Holiday Inn Basel", "location": "Basel", "price_tier": "Upper Midscale"}]
|
||||
"""
|
||||
|
||||
|
||||
def book_hotel(hotel_name: str) -> str:
|
||||
return 'OK, your hotel has been booked.'
|
||||
|
||||
|
||||
def before_model_call(agent: Agent, session: Session, user_message):
|
||||
if 'expedia' in user_message.lower():
|
||||
response = types.Content(
|
||||
role='model',
|
||||
parts=[types.Part(text="Sorry, I can't answer this question.")],
|
||||
)
|
||||
return response
|
||||
return None
|
||||
|
||||
|
||||
def after_model_call(
|
||||
agent: Agent, session: Session, content: types.Content
|
||||
) -> bool:
|
||||
model_message = content.parts[0].text
|
||||
if 'expedia' in model_message.lower():
|
||||
response = types.Content(
|
||||
role='model',
|
||||
parts=[types.Part(text="Sorry, I can't answer this question.")],
|
||||
)
|
||||
return response
|
||||
return None
|
||||
|
||||
|
||||
flight_agent = Agent(
|
||||
model='gemini-1.5-pro',
|
||||
name='flight_agent',
|
||||
description='Handles flight information, policy and updates',
|
||||
instruction="""
|
||||
You are a specialized assistant for handling flight updates.
|
||||
The primary assistant delegates work to you whenever the user needs help updating their bookings.
|
||||
Confirm the updated flight details with the customer and inform them of any additional fees.
|
||||
When searching, be persistent. Expand your query bounds if the first search returns no results.
|
||||
Remember that a booking isn't completed until after the relevant tool has successfully been used.
|
||||
Do not waste the user's time. Do not make up invalid tools or functions.
|
||||
""",
|
||||
tools=[
|
||||
list_customer_flights,
|
||||
lookup_company_policy,
|
||||
fetch_user_flight_information,
|
||||
search_flights,
|
||||
update_ticket_to_new_flight,
|
||||
],
|
||||
)
|
||||
|
||||
hotel_agent = Agent(
|
||||
model='gemini-1.5-pro',
|
||||
name='hotel_agent',
|
||||
description='Handles hotel information and booking',
|
||||
instruction="""
|
||||
You are a specialized assistant for handling hotel bookings.
|
||||
The primary assistant delegates work to you whenever the user needs help booking a hotel.
|
||||
Search for available hotels based on the user's preferences and confirm the booking details with the customer.
|
||||
When searching, be persistent. Expand your query bounds if the first search returns no results.
|
||||
""",
|
||||
tools=[search_hotels, book_hotel],
|
||||
)
|
||||
|
||||
|
||||
idea_agent = RemoteAgent(
|
||||
model='gemini-1.5-pro',
|
||||
name='idea_agent',
|
||||
description='Provide travel ideas base on the destination.',
|
||||
url='http://localhost:8000/agent/run',
|
||||
)
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-1.5-pro',
|
||||
name='root_agent',
|
||||
instruction="""
|
||||
You are a helpful customer support assistant for Swiss Airlines.
|
||||
""",
|
||||
sub_agents=[flight_agent, hotel_agent, idea_agent],
|
||||
flow='auto',
|
||||
examples=[
|
||||
Example(
|
||||
input=types.Content(
|
||||
role='user',
|
||||
parts=[types.Part(text='How were you built?')],
|
||||
),
|
||||
output=[
|
||||
types.Content(
|
||||
role='model',
|
||||
parts=[
|
||||
types.Part(
|
||||
text='I was built with the best agent framework.'
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,338 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk import Agent
|
||||
|
||||
# A lightweight in-memory mock database
|
||||
ORDER_DB = {
|
||||
"1": "FINISHED",
|
||||
"2": "CANCELED",
|
||||
"3": "PENDING",
|
||||
"4": "PENDING",
|
||||
} # Order id to status mapping. Available states: 'FINISHED', 'PENDING', and 'CANCELED'
|
||||
USER_TO_ORDER_DB = {
|
||||
"user_a": ["1", "4"],
|
||||
"user_b": ["2"],
|
||||
"user_c": ["3"],
|
||||
} # User id to Order id mapping
|
||||
TICKET_DB = [{
|
||||
"ticket_id": "1",
|
||||
"user_id": "user_a",
|
||||
"issue_type": "LOGIN_ISSUE",
|
||||
"status": "OPEN",
|
||||
}] # Available states: 'OPEN', 'CLOSED', 'ESCALATED'
|
||||
USER_INFO_DB = {
|
||||
"user_a": {"name": "Alice", "email": "alice@example.com"},
|
||||
"user_b": {"name": "Bob", "email": "bob@example.com"},
|
||||
}
|
||||
|
||||
|
||||
def reset_data():
|
||||
global ORDER_DB
|
||||
global USER_TO_ORDER_DB
|
||||
global TICKET_DB
|
||||
global USER_INFO_DB
|
||||
ORDER_DB = {
|
||||
"1": "FINISHED",
|
||||
"2": "CANCELED",
|
||||
"3": "PENDING",
|
||||
"4": "PENDING",
|
||||
}
|
||||
USER_TO_ORDER_DB = {
|
||||
"user_a": ["1", "4"],
|
||||
"user_b": ["2"],
|
||||
"user_c": ["3"],
|
||||
}
|
||||
TICKET_DB = [{
|
||||
"ticket_id": "1",
|
||||
"user_id": "user_a",
|
||||
"issue_type": "LOGIN_ISSUE",
|
||||
"status": "OPEN",
|
||||
}]
|
||||
USER_INFO_DB = {
|
||||
"user_a": {"name": "Alice", "email": "alice@example.com"},
|
||||
"user_b": {"name": "Bob", "email": "bob@example.com"},
|
||||
}
|
||||
|
||||
|
||||
def get_order_status(order_id: str) -> str:
|
||||
"""Get the status of an order.
|
||||
|
||||
Args:
|
||||
order_id (str): The unique identifier of the order.
|
||||
|
||||
Returns:
|
||||
str: The status of the order (e.g., 'FINISHED', 'CANCELED', 'PENDING'),
|
||||
or 'Order not found' if the order_id does not exist.
|
||||
"""
|
||||
return ORDER_DB.get(order_id, "Order not found")
|
||||
|
||||
|
||||
def get_order_ids_for_user(user_id: str) -> list:
|
||||
"""Get the list of order IDs assigned to a specific transaction associated with a user.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of order IDs associated with the user, or an empty list
|
||||
if no orders are found.
|
||||
"""
|
||||
return USER_TO_ORDER_DB.get(user_id, [])
|
||||
|
||||
|
||||
def cancel_order(order_id: str) -> str:
|
||||
"""Cancel an order if it is in a 'PENDING' state.
|
||||
|
||||
You should call "get_order_status" to check the status first, before calling
|
||||
this tool.
|
||||
|
||||
Args:
|
||||
order_id (str): The unique identifier of the order to be canceled.
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the order was successfully canceled or
|
||||
not.
|
||||
"""
|
||||
if order_id in ORDER_DB and ORDER_DB[order_id] == "PENDING":
|
||||
ORDER_DB[order_id] = "CANCELED"
|
||||
return f"Order {order_id} has been canceled."
|
||||
return f"Order {order_id} cannot be canceled."
|
||||
|
||||
|
||||
def refund_order(order_id: str) -> str:
|
||||
"""Process a refund for an order if it is in a 'CANCELED' state.
|
||||
|
||||
You should call "get_order_status" to check if status first, before calling
|
||||
this tool.
|
||||
|
||||
Args:
|
||||
order_id (str): The unique identifier of the order to be refunded.
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the order was successfully refunded or
|
||||
not.
|
||||
"""
|
||||
if order_id in ORDER_DB and ORDER_DB[order_id] == "CANCELED":
|
||||
return f"Order {order_id} has been refunded."
|
||||
return f"Order {order_id} cannot be refunded."
|
||||
|
||||
|
||||
def create_ticket(user_id: str, issue_type: str) -> str:
|
||||
"""Create a new support ticket for a user.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user creating the ticket.
|
||||
issue_type (str): An issue type the user is facing. Available types:
|
||||
'LOGIN_ISSUE', 'ORDER_ISSUE', 'OTHER'.
|
||||
|
||||
Returns:
|
||||
str: A message indicating that the ticket was created successfully,
|
||||
including the ticket ID.
|
||||
"""
|
||||
ticket_id = str(len(TICKET_DB) + 1)
|
||||
TICKET_DB.append({
|
||||
"ticket_id": ticket_id,
|
||||
"user_id": user_id,
|
||||
"issue_type": issue_type,
|
||||
"status": "OPEN",
|
||||
})
|
||||
return f"Ticket {ticket_id} created successfully."
|
||||
|
||||
|
||||
def get_ticket_info(ticket_id: str) -> str:
|
||||
"""Retrieve the information of a support ticket.
|
||||
|
||||
current status of a support ticket.
|
||||
|
||||
Args:
|
||||
ticket_id (str): The unique identifier of the ticket.
|
||||
|
||||
Returns:
|
||||
A dictionary contains the following fields, or 'Ticket not found' if the
|
||||
ticket_id does not exist:
|
||||
- "ticket_id": str, the current ticket id
|
||||
- "user_id": str, the associated user id
|
||||
- "issue": str, the issue type
|
||||
- "status": The current status of the ticket (e.g., 'OPEN', 'CLOSED',
|
||||
'ESCALATED')
|
||||
|
||||
Example: {"ticket_id": "1", "user_id": "user_a", "issue": "Login issue",
|
||||
"status": "OPEN"}
|
||||
"""
|
||||
for ticket in TICKET_DB:
|
||||
if ticket["ticket_id"] == ticket_id:
|
||||
return ticket
|
||||
return "Ticket not found"
|
||||
|
||||
|
||||
def get_tickets_for_user(user_id: str) -> list:
|
||||
"""Get all the ticket IDs associated with a user.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of ticket IDs associated with the user.
|
||||
If no tickets are found, returns an empty list.
|
||||
"""
|
||||
return [
|
||||
ticket["ticket_id"]
|
||||
for ticket in TICKET_DB
|
||||
if ticket["user_id"] == user_id
|
||||
]
|
||||
|
||||
|
||||
def update_ticket_status(ticket_id: str, status: str) -> str:
|
||||
"""Update the status of a support ticket.
|
||||
|
||||
Args:
|
||||
ticket_id (str): The unique identifier of the ticket.
|
||||
status (str): The new status to assign to the ticket (e.g., 'OPEN',
|
||||
'CLOSED', 'ESCALATED').
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the ticket status was successfully
|
||||
updated.
|
||||
"""
|
||||
for ticket in TICKET_DB:
|
||||
if ticket["ticket_id"] == ticket_id:
|
||||
ticket["status"] = status
|
||||
return f"Ticket {ticket_id} status updated to {status}."
|
||||
return "Ticket not found"
|
||||
|
||||
|
||||
def get_user_info(user_id: str) -> dict:
|
||||
"""Retrieve information (name, email) about a user.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user.
|
||||
|
||||
Returns:
|
||||
dict or str: A dictionary containing user information of the following
|
||||
fields, or 'User not found' if the user_id does not exist:
|
||||
|
||||
- name: The name of the user
|
||||
- email: The email address of the user
|
||||
|
||||
For example, {"name": "Chelsea", "email": "123@example.com"}
|
||||
"""
|
||||
return USER_INFO_DB.get(user_id, "User not found")
|
||||
|
||||
|
||||
def send_email(user_id: str, email: str) -> list:
|
||||
"""Send email to user for notification.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user.
|
||||
email (str): The email address of the user.
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the email was successfully sent.
|
||||
"""
|
||||
if user_id in USER_INFO_DB:
|
||||
return f"Email sent to {email} for user id {user_id}"
|
||||
return "Cannot find this user"
|
||||
|
||||
|
||||
# def update_user_info(user_id: str, new_info: dict[str, str]) -> str:
|
||||
def update_user_info(user_id: str, email: str, name: str) -> str:
|
||||
"""Update a user's information.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user.
|
||||
new_info (dict): A dictionary containing the fields to be updated (e.g.,
|
||||
{'email': 'new_email@example.com'}). Available field keys: 'email' and
|
||||
'name'.
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the user's information was successfully
|
||||
updated or not.
|
||||
"""
|
||||
if user_id in USER_INFO_DB:
|
||||
# USER_INFO_DB[user_id].update(new_info)
|
||||
if email and name:
|
||||
USER_INFO_DB[user_id].update({"email": email, "name": name})
|
||||
elif email:
|
||||
USER_INFO_DB[user_id].update({"email": email})
|
||||
elif name:
|
||||
USER_INFO_DB[user_id].update({"name": name})
|
||||
else:
|
||||
raise ValueError("this should not happen.")
|
||||
return f"User {user_id} information updated."
|
||||
return "User not found"
|
||||
|
||||
|
||||
def get_user_id_from_cookie() -> str:
|
||||
"""Get user ID(username) from the cookie.
|
||||
|
||||
Only use this function when you do not know user ID(username).
|
||||
|
||||
Args: None
|
||||
|
||||
Returns:
|
||||
str: The user ID.
|
||||
"""
|
||||
return "user_a"
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model="gemini-2.0-flash-001",
|
||||
name="Ecommerce_Customer_Service",
|
||||
instruction="""
|
||||
You are an intelligent customer service assistant for an e-commerce platform. Your goal is to accurately understand user queries and use the appropriate tools to fulfill requests. Follow these guidelines:
|
||||
|
||||
1. **Understand the Query**:
|
||||
- Identify actions and conditions (e.g., create a ticket only for pending orders).
|
||||
- Extract necessary details (e.g., user ID, order ID) from the query or infer them from the context.
|
||||
|
||||
2. **Plan Multi-Step Workflows**:
|
||||
- Break down complex queries into sequential steps. For example
|
||||
- typical workflow:
|
||||
- Retrieve IDs or references first (e.g., orders for a user).
|
||||
- Evaluate conditions (e.g., check order status).
|
||||
- Perform actions (e.g., create a ticket) only when conditions are met.
|
||||
- another typical workflows - order cancellation and refund:
|
||||
- Retrieve all orders for the user (`get_order_ids_for_user`).
|
||||
- Cancel pending orders (`cancel_order`).
|
||||
- Refund canceled orders (`refund_order`).
|
||||
- Notify the user (`send_email`).
|
||||
- another typical workflows - send user report:
|
||||
- Get user id.
|
||||
- Get user info(like emails)
|
||||
- Send email to user.
|
||||
|
||||
3. **Avoid Skipping Steps**:
|
||||
- Ensure each intermediate step is completed before moving to the next.
|
||||
- Do not create tickets or take other actions without verifying the conditions specified in the query.
|
||||
|
||||
4. **Provide Clear Responses**:
|
||||
- Confirm the actions performed, including details like ticket ID or pending orders.
|
||||
- Ensure the response aligns with the steps taken and query intent.
|
||||
""",
|
||||
tools=[
|
||||
get_order_status,
|
||||
cancel_order,
|
||||
get_order_ids_for_user,
|
||||
refund_order,
|
||||
create_ticket,
|
||||
update_ticket_status,
|
||||
get_tickets_for_user,
|
||||
get_ticket_info,
|
||||
get_user_info,
|
||||
send_email,
|
||||
update_user_info,
|
||||
get_user_id_from_cookie,
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
[
|
||||
{
|
||||
"query": "Send an email to user user_a whose email address is alice@example.com",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "send_email",
|
||||
"tool_input": {
|
||||
"email": "alice@example.com",
|
||||
"user_id": "user_a"
|
||||
}
|
||||
}
|
||||
],
|
||||
"reference": "Email sent to alice@example.com for user id user_a."
|
||||
},
|
||||
{
|
||||
"query": "Can you tell me the status of my order with ID 1?",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "get_order_status",
|
||||
"tool_input": {
|
||||
"order_id": "1"
|
||||
}
|
||||
}
|
||||
],
|
||||
"reference": "Your order with ID 1 is FINISHED."
|
||||
},
|
||||
{
|
||||
"query": "Cancel all pending order for the user with user id user_a",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "get_order_ids_for_user",
|
||||
"tool_input": {
|
||||
"user_id": "user_a"
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_name": "get_order_status",
|
||||
"tool_input": {
|
||||
"order_id": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_name": "get_order_status",
|
||||
"tool_input": {
|
||||
"order_id": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_name": "cancel_order",
|
||||
"tool_input": {
|
||||
"order_id": "4"
|
||||
}
|
||||
}
|
||||
],
|
||||
"reference": "I have checked your orders and order 4 was in pending status, so I have cancelled it. Order 1 was already finished and couldn't be cancelled.\n"
|
||||
},
|
||||
{
|
||||
"query": "What orders have I placed under the username user_b?",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "get_order_ids_for_user",
|
||||
"tool_input": {
|
||||
"user_id": "user_b"
|
||||
}
|
||||
}
|
||||
],
|
||||
"reference": "User user_b has placed one order with order ID 2.\n"
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"criteria": {
|
||||
"tool_trajectory_avg_score": 0.7,
|
||||
"response_match_score": 0.5
|
||||
}
|
||||
}
|
||||
15
tests/integration/fixture/flow_complex_spark/__init__.py
Normal file
15
tests/integration/fixture/flow_complex_spark/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
182
tests/integration/fixture/flow_complex_spark/agent.py
Normal file
182
tests/integration/fixture/flow_complex_spark/agent.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk import Agent
|
||||
from google.genai import types
|
||||
|
||||
research_plan_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="research_plan_agent",
|
||||
description="I can help generate research plan.",
|
||||
instruction="""\
|
||||
Your task is to create a research plan according to the user's query.
|
||||
|
||||
# Here are the instructions for creating the research plan:
|
||||
|
||||
+ Focus on finding specific things, e.g. products, data, etc.
|
||||
+ Have the personality of a work colleague that is very helpful and explains things very nicely.
|
||||
+ Don't mention your name unless you are asked.
|
||||
+ Think about the most common things that you would need to research.
|
||||
+ Think about possible answers when creating the plan.
|
||||
+ Your task is to create the sections that should be researched. You will output high level headers, preceded by ##
|
||||
+ Underneath each header, write a short sentence on what we want to find there.
|
||||
+ The headers will follow the logical analysis pattern, as well as logical exploration pattern.
|
||||
+ The headers should be a statement, not be in the form of questions.
|
||||
+ The header will not include roman numerals or anything of the sort, e.g. ":", etc
|
||||
+ Do not include things that you cannot possibly know about from using Google Search: e.g. sales forecasting, competitors, profitability analysis, etc.
|
||||
+ Do not have an executive summary
|
||||
+ In each section describe specifically what will be researched.
|
||||
+ Never use "we will", but rather "I will".
|
||||
+ Don't ask for clarifications from the user.
|
||||
+ Do not ask the user for clarifications or if they have any other questions.
|
||||
+ All headers should be bolded.
|
||||
+ If you have steps in the plan that depend on other information, make sure they are 2 diferent sections in the plan.
|
||||
+ At the end mention that you will start researching.
|
||||
|
||||
# Instruction on replying format
|
||||
|
||||
+ Start with your name as "[research_plan_agent]: ".
|
||||
+ Output the content you want to say.
|
||||
|
||||
Output summary:
|
||||
""",
|
||||
flow="single",
|
||||
sub_agents=[],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
question_generation_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="question_generation_agent",
|
||||
description="I can help generate questions related to user's question.",
|
||||
instruction="""\
|
||||
Generate questions related to the research plan generated by research_plan_agent.
|
||||
|
||||
# Instruction on replying format
|
||||
|
||||
Your reply should be a numbered lsit.
|
||||
|
||||
For each question, reply in the following format: "[question_generation_agent]: [generated questions]"
|
||||
|
||||
Here is an example of the generated question list:
|
||||
|
||||
1. [question_generation_agent]: which state is San Jose in?
|
||||
2. [question_generation_agent]: how google website is designed?
|
||||
""",
|
||||
flow="single",
|
||||
sub_agents=[],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
),
|
||||
)
|
||||
|
||||
information_retrieval_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="information_retrieval_agent",
|
||||
description=(
|
||||
"I can help retrieve information related to question_generation_agent's"
|
||||
" question."
|
||||
),
|
||||
instruction="""\
|
||||
Inspect all the questions after "[question_generation_agent]: " and asnwer them.
|
||||
|
||||
# Instruction on replying format
|
||||
|
||||
Always start with "[information_retrieval_agent]: "
|
||||
|
||||
For the answer of one question:
|
||||
|
||||
- Start with a title with one line summary of the reply.
|
||||
- The title line should be bolded and starts with No.x of the corresponding question.
|
||||
- Have a paragraph of detailed explain.
|
||||
|
||||
# Instruction on exiting the loop
|
||||
|
||||
- If you see there are less than 20 questions by "question_generation_agent", do not say "[exit]".
|
||||
- If you see there are already great or equal to 20 questions asked by "question_generation_agent", say "[exit]" at last to exit the loop.
|
||||
""",
|
||||
flow="single",
|
||||
sub_agents=[],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
),
|
||||
)
|
||||
|
||||
question_sources_generation_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="question_sources_generation_agent",
|
||||
description=(
|
||||
"I can help generate questions and retrieve related information."
|
||||
),
|
||||
instruction="Generate questions and retrieve information.",
|
||||
flow="loop",
|
||||
sub_agents=[
|
||||
question_generation_agent,
|
||||
information_retrieval_agent,
|
||||
],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
),
|
||||
)
|
||||
|
||||
summary_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="summary_agent",
|
||||
description="I can help summarize information of previous content.",
|
||||
instruction="""\
|
||||
Summarize information in all historical messages that were replied by "question_generation_agent" and "information_retrieval_agent".
|
||||
|
||||
# Instruction on replying format
|
||||
|
||||
- The output should be like an essay that has a title, an abstract, multiple paragraphs for each topic and a conclusion.
|
||||
- Each paragraph should maps to one or more question in historical content.
|
||||
""",
|
||||
flow="single",
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.8,
|
||||
),
|
||||
)
|
||||
|
||||
research_assistant = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="research_assistant",
|
||||
description="I can help with research question.",
|
||||
instruction="Help customers with their need.",
|
||||
flow="sequential",
|
||||
sub_agents=[
|
||||
research_plan_agent,
|
||||
question_sources_generation_agent,
|
||||
summary_agent,
|
||||
],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
),
|
||||
)
|
||||
|
||||
spark_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="spark_assistant",
|
||||
description="I can help with non-research question.",
|
||||
instruction="Help customers with their need.",
|
||||
flow="auto",
|
||||
sub_agents=[research_assistant],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
),
|
||||
)
|
||||
|
||||
root_agent = spark_agent
|
||||
190
tests/integration/fixture/flow_complex_spark/sample.session.json
Normal file
190
tests/integration/fixture/flow_complex_spark/sample.session.json
Normal file
File diff suppressed because one or more lines are too long
15
tests/integration/fixture/hello_world_agent/__init__.py
Normal file
15
tests/integration/fixture/hello_world_agent/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
95
tests/integration/fixture/hello_world_agent/agent.py
Normal file
95
tests/integration/fixture/hello_world_agent/agent.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Hello world agent from agent 1.0 - https://colab.sandbox.google.com/drive/1Zq-nqmgK0nCERCv8jKIaoeTTgbNn6oSo?resourcekey=0-GYaz9pFT4wY8CI8Cvjy5GA#scrollTo=u3X3XwDOaCv9
|
||||
import random
|
||||
|
||||
from google.adk import Agent
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def roll_die(sides: int) -> int:
|
||||
"""Roll a die and return the rolled result.
|
||||
|
||||
Args:
|
||||
sides: The integer number of sides the die has.
|
||||
|
||||
Returns:
|
||||
An integer of the result of rolling the die.
|
||||
"""
|
||||
return random.randint(1, sides)
|
||||
|
||||
|
||||
def check_prime(nums: list[int]) -> list[str]:
|
||||
"""Check if a given list of numbers are prime.
|
||||
|
||||
Args:
|
||||
nums: The list of numbers to check.
|
||||
|
||||
Returns:
|
||||
A str indicating which number is prime.
|
||||
"""
|
||||
primes = set()
|
||||
for number in nums:
|
||||
number = int(number)
|
||||
if number <= 1:
|
||||
continue
|
||||
is_prime = True
|
||||
for i in range(2, int(number**0.5) + 1):
|
||||
if number % i == 0:
|
||||
is_prime = False
|
||||
break
|
||||
if is_prime:
|
||||
primes.add(number)
|
||||
return (
|
||||
'No prime numbers found.'
|
||||
if not primes
|
||||
else f"{', '.join(str(num) for num in primes)} are prime numbers."
|
||||
)
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-2.0-flash-001',
|
||||
name='data_processing_agent',
|
||||
instruction="""
|
||||
You roll dice and answer questions about the outcome of the dice rolls.
|
||||
You can roll dice of different sizes.
|
||||
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
|
||||
The only things you do are roll dice for the user and discuss the outcomes.
|
||||
It is ok to discuss previous dice roles, and comment on the dice rolls.
|
||||
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
|
||||
You should never roll a die on your own.
|
||||
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
|
||||
You should not check prime numbers before calling the tool.
|
||||
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
|
||||
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
|
||||
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
|
||||
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
|
||||
3. When you respond, you must include the roll_die result from step 1.
|
||||
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
|
||||
You should not rely on the previous history on prime results.
|
||||
""",
|
||||
tools=[
|
||||
roll_die,
|
||||
check_prime,
|
||||
],
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
safety_settings=[
|
||||
types.SafetySetting( # avoid false alarm about rolling dice.
|
||||
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=types.HarmBlockThreshold.OFF,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,24 @@
|
||||
[
|
||||
{
|
||||
"query": "Hi who are you?",
|
||||
"expected_tool_use": [],
|
||||
"reference": "I am a data processing agent. I can roll dice and check if the results are prime numbers. What would you like me to do? \n"
|
||||
},
|
||||
{
|
||||
"query": "What can you do?",
|
||||
"expected_tool_use": [],
|
||||
"reference": "I can roll dice for you of different sizes, and I can check if the results are prime numbers. I can also remember previous rolls if you'd like to check those for primes as well. What would you like me to do? \n"
|
||||
},
|
||||
{
|
||||
"query": "Can you roll a die with 6 sides",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "roll_die",
|
||||
"tool_input": {
|
||||
"sides": 6
|
||||
}
|
||||
}
|
||||
],
|
||||
"reference": null
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"criteria": {
|
||||
"tool_trajectory_avg_score": 1.0,
|
||||
"response_match_score": 0.5
|
||||
}
|
||||
}
|
||||
15
tests/integration/fixture/home_automation_agent/__init__.py
Normal file
15
tests/integration/fixture/home_automation_agent/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
304
tests/integration/fixture/home_automation_agent/agent.py
Normal file
304
tests/integration/fixture/home_automation_agent/agent.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from google.adk import Agent
|
||||
|
||||
DEVICE_DB = {
|
||||
"device_1": {"status": "ON", "location": "Living Room"},
|
||||
"device_2": {"status": "OFF", "location": "Bedroom"},
|
||||
"device_3": {"status": "OFF", "location": "Kitchen"},
|
||||
}
|
||||
|
||||
TEMPERATURE_DB = {
|
||||
"Living Room": 22,
|
||||
"Bedroom": 20,
|
||||
"Kitchen": 24,
|
||||
}
|
||||
|
||||
SCHEDULE_DB = {
|
||||
"device_1": {"time": "18:00", "status": "ON"},
|
||||
"device_2": {"time": "22:00", "status": "OFF"},
|
||||
}
|
||||
|
||||
USER_PREFERENCES_DB = {
|
||||
"user_x": {"preferred_temp": 21, "location": "Bedroom"},
|
||||
"user_x": {"preferred_temp": 21, "location": "Living Room"},
|
||||
"user_y": {"preferred_temp": 23, "location": "Living Room"},
|
||||
}
|
||||
|
||||
|
||||
def reset_data():
|
||||
global DEVICE_DB
|
||||
global TEMPERATURE_DB
|
||||
global SCHEDULE_DB
|
||||
global USER_PREFERENCES_DB
|
||||
DEVICE_DB = {
|
||||
"device_1": {"status": "ON", "location": "Living Room"},
|
||||
"device_2": {"status": "OFF", "location": "Bedroom"},
|
||||
"device_3": {"status": "OFF", "location": "Kitchen"},
|
||||
}
|
||||
|
||||
TEMPERATURE_DB = {
|
||||
"Living Room": 22,
|
||||
"Bedroom": 20,
|
||||
"Kitchen": 24,
|
||||
}
|
||||
|
||||
SCHEDULE_DB = {
|
||||
"device_1": {"time": "18:00", "status": "ON"},
|
||||
"device_2": {"time": "22:00", "status": "OFF"},
|
||||
}
|
||||
|
||||
USER_PREFERENCES_DB = {
|
||||
"user_x": {"preferred_temp": 21, "location": "Bedroom"},
|
||||
"user_x": {"preferred_temp": 21, "location": "Living Room"},
|
||||
"user_y": {"preferred_temp": 23, "location": "Living Room"},
|
||||
}
|
||||
|
||||
|
||||
def get_device_info(device_id: str) -> dict:
|
||||
"""Get the current status and location of a AC device.
|
||||
|
||||
Args:
|
||||
device_id (str): The unique identifier of the device.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the following fields, or 'Device not found'
|
||||
if the device_id does not exist:
|
||||
- status: The current status of the device (e.g., 'ON', 'OFF')
|
||||
- location: The location where the device is installed (e.g., 'Living
|
||||
Room', 'Bedroom', ''Kitchen')
|
||||
"""
|
||||
return DEVICE_DB.get(device_id, "Device not found")
|
||||
|
||||
|
||||
# def set_device_info(device_id: str, updates: dict) -> str:
|
||||
# """Update the information of a AC device, specifically its status and/or location.
|
||||
|
||||
# Args:
|
||||
# device_id (str): Required. The unique identifier of the device.
|
||||
# updates (dict): Required. A dictionary containing the fields to be
|
||||
# updated. Supported keys: - "status" (str): The new status to set for the
|
||||
# device. Accepted values: 'ON', 'OFF'. **Only these values are allowed.**
|
||||
# - "location" (str): The new location to set for the device. Accepted
|
||||
# values: 'Living Room', 'Bedroom', 'Kitchen'. **Only these values are
|
||||
# allowed.**
|
||||
|
||||
|
||||
# Returns:
|
||||
# str: A message indicating whether the device information was successfully
|
||||
# updated.
|
||||
# """
|
||||
# if device_id in DEVICE_DB:
|
||||
# if "status" in updates:
|
||||
# DEVICE_DB[device_id]["status"] = updates["status"]
|
||||
# if "location" in updates:
|
||||
# DEVICE_DB[device_id]["location"] = updates["location"]
|
||||
# return f"Device {device_id} information updated: {updates}."
|
||||
# return "Device not found"
|
||||
def set_device_info(
|
||||
device_id: str, status: str = "", location: str = ""
|
||||
) -> str:
|
||||
"""Update the information of a AC device, specifically its status and/or location.
|
||||
|
||||
Args:
|
||||
device_id (str): Required. The unique identifier of the device.
|
||||
status (str): The new status to set for the
|
||||
device. Accepted values: 'ON', 'OFF'. **Only these values are allowed.**
|
||||
location (str): The new location to set for the device. Accepted
|
||||
values: 'Living Room', 'Bedroom', 'Kitchen'. **Only these values are
|
||||
allowed.**
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the device information was successfully
|
||||
updated.
|
||||
"""
|
||||
if device_id in DEVICE_DB:
|
||||
if status:
|
||||
DEVICE_DB[device_id]["status"] = status
|
||||
return f"Device {device_id} information updated: status -> {status}."
|
||||
if location:
|
||||
DEVICE_DB[device_id]["location"] = location
|
||||
return f"Device {device_id} information updated: location -> {location}."
|
||||
return "Device not found"
|
||||
|
||||
|
||||
def get_temperature(location: str) -> int:
|
||||
"""Get the current temperature in celsius of a location (e.g., 'Living Room', 'Bedroom', 'Kitchen').
|
||||
|
||||
Args:
|
||||
location (str): The location for which to retrieve the temperature (e.g.,
|
||||
'Living Room', 'Bedroom', 'Kitchen').
|
||||
|
||||
Returns:
|
||||
int: The current temperature in celsius in the specified location, or
|
||||
'Location not found' if the location does not exist.
|
||||
"""
|
||||
return TEMPERATURE_DB.get(location, "Location not found")
|
||||
|
||||
|
||||
def set_temperature(location: str, temperature: int) -> str:
|
||||
"""Set the desired temperature in celsius for a location.
|
||||
|
||||
Acceptable range of temperature: 18-30 celsius. If it's out of the range, do
|
||||
not call this tool.
|
||||
|
||||
Args:
|
||||
location (str): The location where the temperature should be set.
|
||||
temperature (int): The desired temperature as integer to set in celsius.
|
||||
Acceptable range: 18-30 celsius.
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the temperature was successfully set.
|
||||
"""
|
||||
if location in TEMPERATURE_DB:
|
||||
TEMPERATURE_DB[location] = temperature
|
||||
return f"Temperature in {location} set to {temperature}°C."
|
||||
return "Location not found"
|
||||
|
||||
|
||||
def get_user_preferences(user_id: str) -> dict:
|
||||
"""Get the temperature preferences and preferred location of a user_id.
|
||||
|
||||
user_id must be provided.
|
||||
|
||||
Args:
|
||||
user_id (str): The unique identifier of the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the following fields, or 'User not found' if
|
||||
the user_id does not exist:
|
||||
- preferred_temp: The user's preferred temperature.
|
||||
- location: The location where the user prefers to be.
|
||||
"""
|
||||
return USER_PREFERENCES_DB.get(user_id, "User not found")
|
||||
|
||||
|
||||
def set_device_schedule(device_id: str, time: str, status: str) -> str:
|
||||
"""Schedule a device to change its status at a specific time.
|
||||
|
||||
Args:
|
||||
device_id (str): The unique identifier of the device.
|
||||
time (str): The time at which the device should change its status (format:
|
||||
'HH:MM').
|
||||
status (str): The status to set for the device at the specified time
|
||||
(e.g., 'ON', 'OFF').
|
||||
|
||||
Returns:
|
||||
str: A message indicating whether the schedule was successfully set.
|
||||
"""
|
||||
if device_id in DEVICE_DB:
|
||||
SCHEDULE_DB[device_id] = {"time": time, "status": status}
|
||||
return f"Device {device_id} scheduled to turn {status} at {time}."
|
||||
return "Device not found"
|
||||
|
||||
|
||||
def get_device_schedule(device_id: str) -> dict:
|
||||
"""Retrieve the schedule of a device.
|
||||
|
||||
Args:
|
||||
device_id (str): The unique identifier of the device.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the following fields, or 'Schedule not
|
||||
found' if the device_id does not exist:
|
||||
- time: The scheduled time for the device to change its status (format:
|
||||
'HH:MM').
|
||||
- status: The status that will be set at the scheduled time (e.g., 'ON',
|
||||
'OFF').
|
||||
"""
|
||||
return SCHEDULE_DB.get(device_id, "Schedule not found")
|
||||
|
||||
|
||||
def celsius_to_fahrenheit(celsius: int) -> float:
|
||||
"""Convert Celsius to Fahrenheit.
|
||||
|
||||
You must call this to do the conversion of temperature, so you can get the
|
||||
precise number in required format.
|
||||
|
||||
Args:
|
||||
celsius (int): Temperature in Celsius.
|
||||
|
||||
Returns:
|
||||
float: Temperature in Fahrenheit.
|
||||
"""
|
||||
return (celsius * 9 / 5) + 32
|
||||
|
||||
|
||||
def fahrenheit_to_celsius(fahrenheit: float) -> int:
|
||||
"""Convert Fahrenheit to Celsius.
|
||||
|
||||
You must call this to do the conversion of temperature, so you can get the
|
||||
precise number in required format.
|
||||
|
||||
Args:
|
||||
fahrenheit (float): Temperature in Fahrenheit.
|
||||
|
||||
Returns:
|
||||
int: Temperature in Celsius.
|
||||
"""
|
||||
return int((fahrenheit - 32) * 5 / 9)
|
||||
|
||||
|
||||
def list_devices(status: str = "", location: str = "") -> list:
|
||||
"""Retrieve a list of AC devices, filtered by status and/or location when provided.
|
||||
|
||||
For cost efficiency, always apply as many filters (status and location) as
|
||||
available in the input arguments.
|
||||
|
||||
Args:
|
||||
status (str, optional): The status to filter devices by (e.g., 'ON',
|
||||
'OFF'). Defaults to None.
|
||||
location (str, optional): The location to filter devices by (e.g., 'Living
|
||||
Room', 'Bedroom', ''Kitchen'). Defaults to None.
|
||||
|
||||
Returns:
|
||||
list: A list of dictionaries, each containing the device ID, status, and
|
||||
location, or an empty list if no devices match the criteria.
|
||||
"""
|
||||
devices = []
|
||||
for device_id, info in DEVICE_DB.items():
|
||||
if ((not status) or info["status"] == status) and (
|
||||
(not location) or info["location"] == location
|
||||
):
|
||||
devices.append({
|
||||
"device_id": device_id,
|
||||
"status": info["status"],
|
||||
"location": info["location"],
|
||||
})
|
||||
return devices if devices else "No devices found matching the criteria."
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model="gemini-2.0-flash-001",
|
||||
name="Home_automation_agent",
|
||||
instruction="""
|
||||
You are Home Automation Agent. You are responsible for controlling the devices in the home.
|
||||
""",
|
||||
tools=[
|
||||
get_device_info,
|
||||
set_device_info,
|
||||
get_temperature,
|
||||
set_temperature,
|
||||
get_user_preferences,
|
||||
set_device_schedule,
|
||||
get_device_schedule,
|
||||
celsius_to_fahrenheit,
|
||||
fahrenheit_to_celsius,
|
||||
list_devices,
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
[{
|
||||
"query": "Turn off device_2 in the Bedroom.",
|
||||
"expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}}],
|
||||
"reference": "I have set the device_2 status to off."
|
||||
}]
|
||||
@@ -0,0 +1,5 @@
|
||||
[{
|
||||
"query": "Turn off device_3 in the Bedroom.",
|
||||
"expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_3", "status": "OFF"}}],
|
||||
"reference": "I have set the device_3 status to off."
|
||||
}]
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"criteria": {
|
||||
"tool_trajectory_avg_score": 1.0
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
[
|
||||
{
|
||||
"query": "Turn off device_2 in the Bedroom.",
|
||||
"expected_tool_use": [{
|
||||
"tool_name": "set_device_info",
|
||||
"tool_input": {"location": "Bedroom", "status": "OFF", "device_id": "device_2"}
|
||||
}],
|
||||
"reference": "I have set the device 2 status to off."
|
||||
},
|
||||
{
|
||||
"query": "What's the status of device_2 in the Bedroom?",
|
||||
"expected_tool_use": [{
|
||||
"tool_name": "get_device_info",
|
||||
"tool_input": {"device_id": "device_2"}
|
||||
}],
|
||||
"reference": "Status of device_2 is off."
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
[
|
||||
{
|
||||
"query": "Turn off device_2 in the Bedroom.",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "set_device_info",
|
||||
"tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
|
||||
}
|
||||
],
|
||||
"reference": "OK. I've turned off device_2 in the Bedroom. Anything else?\n"
|
||||
},
|
||||
{
|
||||
"query": "What's the command I just issued?",
|
||||
"expected_tool_use": [],
|
||||
"reference": "You asked me to turn off device_2 in the Bedroom.\n"
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"criteria": {
|
||||
"tool_trajectory_avg_score": 1.0,
|
||||
"response_match_score": 0.5
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
[
|
||||
{
|
||||
"query": "Turn off device_2 in the Bedroom.",
|
||||
"expected_tool_use": [{
|
||||
"tool_name": "set_device_info",
|
||||
"tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
|
||||
}],
|
||||
"reference": "I have set the device 2 status to off."
|
||||
},
|
||||
{
|
||||
"query": "Turn on device_2 in the Bedroom.",
|
||||
"expected_tool_use": [{
|
||||
"tool_name": "set_device_info",
|
||||
"tool_input": {"location": "Bedroom", "status": "ON", "device_id": "device_2"}
|
||||
}],
|
||||
"reference": "I have set the device 2 status to on."
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
[
|
||||
{
|
||||
"query": "Turn off device_2 in the Bedroom.",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "set_device_info",
|
||||
"tool_input": {"location": "Bedroom", "device_id": "device_2", "status": "OFF"}
|
||||
}
|
||||
],
|
||||
"reference": "OK. I've turned off device_2 in the Bedroom. Anything else?\n"
|
||||
},
|
||||
{
|
||||
"query": "What's the command I just issued?",
|
||||
"expected_tool_use": [],
|
||||
"reference": "You asked me to turn off device_2 in the Bedroom.\n"
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,5 @@
|
||||
[{
|
||||
"query": "Turn off device_3 in the Bedroom.",
|
||||
"expected_tool_use": [{"tool_name": "set_device_info", "tool_input": {"location": "Bedroom", "device_id": "device_3", "status": "OFF"}}],
|
||||
"reference": "I have set the device_3 status to off."
|
||||
}]
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"criteria": {
|
||||
"tool_trajectory_avg_score": 1.0
|
||||
}
|
||||
}
|
||||
15
tests/integration/fixture/tool_agent/__init__.py
Normal file
15
tests/integration/fixture/tool_agent/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
218
tests/integration/fixture/tool_agent/agent.py
Normal file
218
tests/integration/fixture/tool_agent/agent.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools import DirectoryReadTool
|
||||
from google.adk import Agent
|
||||
from google.adk.tools.agent_tool import AgentTool
|
||||
from google.adk.tools.crewai_tool import CrewaiTool
|
||||
from google.adk.tools.langchain_tool import LangchainTool
|
||||
from google.adk.tools.retrieval.files_retrieval import FilesRetrieval
|
||||
from google.adk.tools.retrieval.vertex_ai_rag_retrieval import VertexAiRagRetrieval
|
||||
from langchain_community.tools import ShellTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TestCase(BaseModel):
|
||||
case: str
|
||||
|
||||
|
||||
class Test(BaseModel):
|
||||
test_title: list[str]
|
||||
|
||||
|
||||
def simple_function(param: str) -> str:
|
||||
if isinstance(param, str):
|
||||
return "Called simple function successfully"
|
||||
return "Called simple function with wrong param type"
|
||||
|
||||
|
||||
def no_param_function() -> str:
|
||||
return "Called no param function successfully"
|
||||
|
||||
|
||||
def no_output_function(param: str):
|
||||
return
|
||||
|
||||
|
||||
def multiple_param_types_function(
|
||||
param1: str, param2: int, param3: float, param4: bool
|
||||
) -> str:
|
||||
if (
|
||||
isinstance(param1, str)
|
||||
and isinstance(param2, int)
|
||||
and isinstance(param3, float)
|
||||
and isinstance(param4, bool)
|
||||
):
|
||||
return "Called multiple param types function successfully"
|
||||
return "Called multiple param types function with wrong param types"
|
||||
|
||||
|
||||
def throw_error_function(param: str) -> str:
|
||||
raise ValueError("Error thrown by throw_error_function")
|
||||
|
||||
|
||||
def list_str_param_function(param: list[str]) -> str:
|
||||
if isinstance(param, list) and all(isinstance(item, str) for item in param):
|
||||
return "Called list str param function successfully"
|
||||
return "Called list str param function with wrong param type"
|
||||
|
||||
|
||||
def return_list_str_function(param: str) -> list[str]:
|
||||
return ["Called return list str function successfully"]
|
||||
|
||||
|
||||
def complex_function_list_dict(
|
||||
param1: dict[str, Any], param2: list[dict[str, Any]]
|
||||
) -> list[Test]:
|
||||
if (
|
||||
isinstance(param1, dict)
|
||||
and isinstance(param2, list)
|
||||
and all(isinstance(item, dict) for item in param2)
|
||||
):
|
||||
return [
|
||||
Test(test_title=["function test 1", "function test 2"]),
|
||||
Test(test_title=["retrieval test"]),
|
||||
]
|
||||
raise ValueError("Wrong param")
|
||||
|
||||
|
||||
def repetive_call_1(param: str):
|
||||
return f"Call repetive_call_2 tool with param {param + '_repetive'}"
|
||||
|
||||
|
||||
def repetive_call_2(param: str):
|
||||
return param
|
||||
|
||||
|
||||
test_case_retrieval = FilesRetrieval(
|
||||
name="test_case_retrieval",
|
||||
description="General guidence for agent test cases",
|
||||
input_dir=os.path.join(os.path.dirname(__file__), "files"),
|
||||
)
|
||||
|
||||
valid_rag_retrieval = VertexAiRagRetrieval(
|
||||
name="valid_rag_retrieval",
|
||||
rag_corpora=[
|
||||
"projects/1096655024998/locations/us-central1/ragCorpora/4985766262475849728"
|
||||
],
|
||||
description="General guidence for agent test cases",
|
||||
)
|
||||
|
||||
invalid_rag_retrieval = VertexAiRagRetrieval(
|
||||
name="invalid_rag_retrieval",
|
||||
rag_corpora=[
|
||||
"projects/1096655024998/locations/us-central1/InValidRagCorporas/4985766262475849728"
|
||||
],
|
||||
description="Invalid rag retrieval resource name",
|
||||
)
|
||||
|
||||
non_exist_rag_retrieval = VertexAiRagRetrieval(
|
||||
name="non_exist_rag_retrieval",
|
||||
rag_corpora=[
|
||||
"projects/1096655024998/locations/us-central1/RagCorpora/1234567"
|
||||
],
|
||||
description="Non exist rag retrieval resource name",
|
||||
)
|
||||
|
||||
shell_tool = LangchainTool(ShellTool())
|
||||
|
||||
docs_tool = CrewaiTool(
|
||||
name="direcotry_read_tool",
|
||||
description="use this to find files for you.",
|
||||
tool=DirectoryReadTool(directory="."),
|
||||
)
|
||||
|
||||
no_schema_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="no_schema_agent",
|
||||
instruction="""Just say 'Hi'
|
||||
""",
|
||||
)
|
||||
|
||||
schema_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="schema_agent",
|
||||
instruction="""
|
||||
You will be given a test case.
|
||||
Return a list of the received test case appended with '_success' and '_failure' as test_titles
|
||||
""",
|
||||
input_schema=TestCase,
|
||||
output_schema=Test,
|
||||
)
|
||||
|
||||
no_input_schema_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="no_input_schema_agent",
|
||||
instruction="""
|
||||
Just return ['Tools_success, Tools_failure']
|
||||
""",
|
||||
output_schema=Test,
|
||||
)
|
||||
|
||||
no_output_schema_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="no_output_schema_agent",
|
||||
instruction="""
|
||||
Just say 'Hi'
|
||||
""",
|
||||
input_schema=TestCase,
|
||||
)
|
||||
|
||||
single_function_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="single_function_agent",
|
||||
description="An agent that calls a single function",
|
||||
instruction="When calling tools, just return what the tool returns.",
|
||||
tools=[simple_function],
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
model="gemini-1.5-flash",
|
||||
name="tool_agent",
|
||||
description="An agent that can call other tools",
|
||||
instruction="When calling tools, just return what the tool returns.",
|
||||
tools=[
|
||||
simple_function,
|
||||
no_param_function,
|
||||
no_output_function,
|
||||
multiple_param_types_function,
|
||||
throw_error_function,
|
||||
list_str_param_function,
|
||||
return_list_str_function,
|
||||
# complex_function_list_dict,
|
||||
repetive_call_1,
|
||||
repetive_call_2,
|
||||
test_case_retrieval,
|
||||
valid_rag_retrieval,
|
||||
invalid_rag_retrieval,
|
||||
non_exist_rag_retrieval,
|
||||
shell_tool,
|
||||
docs_tool,
|
||||
AgentTool(
|
||||
agent=no_schema_agent,
|
||||
),
|
||||
AgentTool(
|
||||
agent=schema_agent,
|
||||
),
|
||||
AgentTool(
|
||||
agent=no_input_schema_agent,
|
||||
),
|
||||
AgentTool(
|
||||
agent=no_output_schema_agent,
|
||||
),
|
||||
],
|
||||
)
|
||||
BIN
tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf
Normal file
BIN
tests/integration/fixture/tool_agent/files/Agent_test_plan.pdf
Normal file
Binary file not shown.
15
tests/integration/fixture/trip_planner_agent/__init__.py
Normal file
15
tests/integration/fixture/trip_planner_agent/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
110
tests/integration/fixture/trip_planner_agent/agent.py
Normal file
110
tests/integration/fixture/trip_planner_agent/agent.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# https://github.com/crewAIInc/crewAI-examples/tree/main/trip_planner
|
||||
|
||||
from google.adk import Agent
|
||||
|
||||
# Agent that selects the best city for the trip.
|
||||
identify_agent = Agent(
|
||||
name='identify_agent',
|
||||
description='Select the best city based on weather, season, and prices.',
|
||||
instruction="""
|
||||
Analyze and select the best city for the trip based
|
||||
on specific criteria such as weather patterns, seasonal
|
||||
events, and travel costs. This task involves comparing
|
||||
multiple cities, considering factors like current weather
|
||||
conditions, upcoming cultural or seasonal events, and
|
||||
overall travel expenses.
|
||||
|
||||
Your final answer must be a detailed
|
||||
report on the chosen city, and everything you found out
|
||||
about it, including the actual flight costs, weather
|
||||
forecast and attractions.
|
||||
|
||||
Traveling from: {origin}
|
||||
City Options: {cities}
|
||||
Trip Date: {range}
|
||||
Traveler Interests: {interests}
|
||||
""",
|
||||
)
|
||||
|
||||
# Agent that gathers information about the city.
|
||||
gather_agent = Agent(
|
||||
name='gather_agent',
|
||||
description='Provide the BEST insights about the selected city',
|
||||
instruction="""
|
||||
As a local expert on this city you must compile an
|
||||
in-depth guide for someone traveling there and wanting
|
||||
to have THE BEST trip ever!
|
||||
Gather information about key attractions, local customs,
|
||||
special events, and daily activity recommendations.
|
||||
Find the best spots to go to, the kind of place only a
|
||||
local would know.
|
||||
This guide should provide a thorough overview of what
|
||||
the city has to offer, including hidden gems, cultural
|
||||
hotspots, must-visit landmarks, weather forecasts, and
|
||||
high level costs.
|
||||
|
||||
The final answer must be a comprehensive city guide,
|
||||
rich in cultural insights and practical tips,
|
||||
tailored to enhance the travel experience.
|
||||
|
||||
Trip Date: {range}
|
||||
Traveling from: {origin}
|
||||
Traveler Interests: {interests}
|
||||
""",
|
||||
)
|
||||
|
||||
# Agent that plans the trip.
|
||||
plan_agent = Agent(
|
||||
name='plan_agent',
|
||||
description="""Create the most amazing travel itineraries with budget and
|
||||
packing suggestions for the city""",
|
||||
instruction="""
|
||||
Expand this guide into a full 7-day travel
|
||||
itinerary with detailed per-day plans, including
|
||||
weather forecasts, places to eat, packing suggestions,
|
||||
and a budget breakdown.
|
||||
|
||||
You MUST suggest actual places to visit, actual hotels
|
||||
to stay and actual restaurants to go to.
|
||||
|
||||
This itinerary should cover all aspects of the trip,
|
||||
from arrival to departure, integrating the city guide
|
||||
information with practical travel logistics.
|
||||
|
||||
Your final answer MUST be a complete expanded travel plan,
|
||||
formatted as markdown, encompassing a daily schedule,
|
||||
anticipated weather conditions, recommended clothing and
|
||||
items to pack, and a detailed budget, ensuring THE BEST
|
||||
TRIP EVER. Be specific and give it a reason why you picked
|
||||
each place, what makes them special!
|
||||
|
||||
Trip Date: {range}
|
||||
Traveling from: {origin}
|
||||
Traveler Interests: {interests}
|
||||
""",
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-2.0-flash-001',
|
||||
name='trip_planner',
|
||||
description='Plan the best trip ever',
|
||||
instruction="""
|
||||
Your goal is to plan the best trip according to information listed above.
|
||||
You describe why did you choose the city, list top 3
|
||||
attactions and provide a detailed itinerary for each day.""",
|
||||
sub_agents=[identify_agent, gather_agent, plan_agent],
|
||||
)
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"id": "test_id",
|
||||
"app_name": "trip_planner_agent",
|
||||
"user_id": "test_user",
|
||||
"state": {
|
||||
"origin": "San Francisco",
|
||||
"interests": "Food, Shopping, Museums",
|
||||
"range": "1000 miles",
|
||||
"cities": ""
|
||||
},
|
||||
"events": [],
|
||||
"last_update_time": 1741218714.258285
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"criteria": {
|
||||
"response_match_score": 0.5
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"id": "test_id",
|
||||
"app_name": "trip_planner_agent",
|
||||
"user_id": "test_user",
|
||||
"state": {
|
||||
"origin": "San Francisco",
|
||||
"interests": "Food, Shopping, Museums",
|
||||
"range": "1000 miles",
|
||||
"cities": ""
|
||||
},
|
||||
"events": [],
|
||||
"last_update_time": 1741218714.258285
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"criteria": {
|
||||
"response_match_score": 0.5
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
[
|
||||
{
|
||||
"query": "Based on my interests, where should I go, Yosemite national park or Los Angeles?",
|
||||
"expected_tool_use": [],
|
||||
"reference": "Given your interests in food, shopping, and museums, Los Angeles would be a better choice than Yosemite National Park. Yosemite is primarily focused on outdoor activities and natural landscapes, while Los Angeles offers a diverse range of culinary experiences, shopping districts, and world-class museums. I will now gather information to create an in-depth guide for your trip to Los Angeles.\n"
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,19 @@
|
||||
[
|
||||
{
|
||||
"query": "Hi, who are you? What can you do?",
|
||||
"expected_tool_use": [],
|
||||
"reference": "I am trip_planner, and my goal is to plan the best trip ever. I can describe why a city was chosen, list its top attractions, and provide a detailed itinerary for each day of the trip.\n"
|
||||
},
|
||||
{
|
||||
"query": "I want to travel from San Francisco to an European country in fall next year. I am considering London and Paris. What is your advice?",
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "transfer_to_agent",
|
||||
"tool_input": {
|
||||
"agent_name": "indentify_agent"
|
||||
}
|
||||
}
|
||||
],
|
||||
"reference": "Okay, I can help you analyze London and Paris to determine which city is better for your trip next fall. I will consider weather patterns, seasonal events, travel costs (including flights from San Francisco), and your interests (food, shopping, and museums). After gathering this information, I'll provide a detailed report on my chosen city.\n"
|
||||
}
|
||||
]
|
||||
14
tests/integration/models/__init__.py
Normal file
14
tests/integration/models/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
65
tests/integration/models/test_google_llm.py
Normal file
65
tests/integration/models/test_google_llm.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.models import LlmRequest
|
||||
from google.adk.models import LlmResponse
|
||||
from google.adk.models.google_llm import Gemini
|
||||
from google.genai import types
|
||||
from google.genai.types import Content
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gemini_llm():
|
||||
return Gemini(model="gemini-1.5-flash")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_request():
|
||||
return LlmRequest(
|
||||
model="gemini-1.5-flash",
|
||||
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
|
||||
config=types.GenerateContentConfig(
|
||||
temperature=0.1,
|
||||
response_modalities=[types.Modality.TEXT],
|
||||
system_instruction="You are a helpful assistant",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async(gemini_llm, llm_request):
|
||||
async for response in gemini_llm.generate_content_async(llm_request):
|
||||
assert isinstance(response, LlmResponse)
|
||||
assert response.content.parts[0].text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_stream(gemini_llm, llm_request):
|
||||
responses = [
|
||||
resp
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=True
|
||||
)
|
||||
]
|
||||
text = ""
|
||||
for i in range(len(responses) - 1):
|
||||
assert responses[i].partial is True
|
||||
assert responses[i].content.parts[0].text
|
||||
text += responses[i].content.parts[0].text
|
||||
|
||||
# Last message should be accumulated text
|
||||
assert responses[-1].content.parts[0].text == text
|
||||
assert not responses[-1].partial
|
||||
70
tests/integration/test_callback.py
Normal file
70
tests/integration/test_callback.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pytest import mark
|
||||
|
||||
from ..unittests.utils import simplify_events
|
||||
from .fixture import callback_agent
|
||||
from .utils import assert_agent_says
|
||||
from .utils import TestRunner
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": callback_agent.agent.before_agent_callback_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_before_agent_call(agent_runner: TestRunner):
|
||||
agent_runner.run("Hi.")
|
||||
|
||||
# Assert the response content
|
||||
assert_agent_says(
|
||||
"End invocation event before agent call.",
|
||||
agent_name="before_agent_callback_agent",
|
||||
agent_runner=agent_runner,
|
||||
)
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": callback_agent.agent.before_model_callback_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_before_model_call(agent_runner: TestRunner):
|
||||
agent_runner.run("Hi.")
|
||||
|
||||
# Assert the response content
|
||||
assert_agent_says(
|
||||
"End invocation event before model call.",
|
||||
agent_name="before_model_callback_agent",
|
||||
agent_runner=agent_runner,
|
||||
)
|
||||
|
||||
|
||||
# TODO: re-enable vertex by removing below line after fixing.
|
||||
@mark.parametrize("llm_backend", ["GOOGLE_AI"], indirect=True)
|
||||
@mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": callback_agent.agent.after_model_callback_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_after_model_call(agent_runner: TestRunner):
|
||||
events = agent_runner.run("Hi.")
|
||||
|
||||
# Assert the response content
|
||||
simplified_events = simplify_events(events)
|
||||
assert simplified_events[0][0] == "after_model_callback_agent"
|
||||
assert simplified_events[0][1].endswith(
|
||||
"Update response event after model call."
|
||||
)
|
||||
67
tests/integration/test_context_variable.py
Normal file
67
tests/integration/test_context_variable.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip until fixed.
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
from .fixture import context_variable_agent
|
||||
from .utils import TestRunner
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": context_variable_agent.agent.state_variable_echo_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_context_variable_missing(agent_runner: TestRunner):
|
||||
with pytest.raises(KeyError) as e_info:
|
||||
agent_runner.run("Hi echo my customer id.")
|
||||
assert "customerId" in str(e_info.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": context_variable_agent.agent.state_variable_update_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_context_variable_update(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"update_fc",
|
||||
["RRRR", "3.141529", ["apple", "banana"], [1, 3.14, "hello"]],
|
||||
"successfully",
|
||||
)
|
||||
|
||||
|
||||
def _call_function_and_assert(
|
||||
agent_runner: TestRunner, function_name: str, params, expected
|
||||
):
|
||||
param_section = (
|
||||
" with params"
|
||||
f" {params if isinstance(params, str) else json.dumps(params)}"
|
||||
if params is not None
|
||||
else ""
|
||||
)
|
||||
agent_runner.run(
|
||||
f"Call {function_name}{param_section} and show me the result"
|
||||
)
|
||||
|
||||
model_response_event = agent_runner.get_events()[-1]
|
||||
assert model_response_event.author == "context_variable_update_agent"
|
||||
assert model_response_event.content.role == "model"
|
||||
assert expected in model_response_event.content.parts[0].text.strip()
|
||||
76
tests/integration/test_evalute_agent_in_fixture.py
Normal file
76
tests/integration/test_evalute_agent_in_fixture.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Evaluate all agents in fixture folder if evaluation test files exist."""
|
||||
|
||||
import os
|
||||
|
||||
from google.adk.evaluation import AgentEvaluator
|
||||
import pytest
|
||||
|
||||
def agent_eval_artifacts_in_fixture():
|
||||
"""Get all agents from fixture folder."""
|
||||
agent_eval_artifacts = []
|
||||
fixture_dir = os.path.join(os.path.dirname(__file__), 'fixture')
|
||||
for agent_name in os.listdir(fixture_dir):
|
||||
agent_dir = os.path.join(fixture_dir, agent_name)
|
||||
if not os.path.isdir(agent_dir):
|
||||
continue
|
||||
for filename in os.listdir(agent_dir):
|
||||
# Evaluation test files end with test.json
|
||||
if not filename.endswith('test.json'):
|
||||
continue
|
||||
initial_session_file = (
|
||||
f'tests/integration/fixture/{agent_name}/initial.session.json'
|
||||
)
|
||||
agent_eval_artifacts.append((
|
||||
f'tests.integration.fixture.{agent_name}',
|
||||
f'tests/integration/fixture/{agent_name}/{filename}',
|
||||
initial_session_file
|
||||
if os.path.exists(initial_session_file)
|
||||
else None,
|
||||
))
|
||||
|
||||
# This method gets invoked twice, sorting helps ensure that both the
|
||||
# invocations have the same view.
|
||||
agent_eval_artifacts = sorted(
|
||||
agent_eval_artifacts, key=lambda item: f'{item[0]}|{item[1]}'
|
||||
)
|
||||
return agent_eval_artifacts
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'agent_name, evalfile, initial_session_file',
|
||||
agent_eval_artifacts_in_fixture(),
|
||||
ids=[agent_name for agent_name, _, _ in agent_eval_artifacts_in_fixture()],
|
||||
)
|
||||
def test_evaluate_agents_long_running_4_runs_per_eval_item(
|
||||
agent_name, evalfile, initial_session_file
|
||||
):
|
||||
"""Test agents evaluation in fixture folder.
|
||||
|
||||
After querying the fixture folder, we have 5 eval items. For each eval item
|
||||
we use 4 runs.
|
||||
|
||||
A single eval item is a session that can have multiple queries in it.
|
||||
"""
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module=agent_name,
|
||||
eval_dataset_file_path_or_dir=evalfile,
|
||||
initial_session_file=initial_session_file,
|
||||
# Using a slightly higher value helps us manange the variances that may
|
||||
# happen in each eval.
|
||||
# This, of course, comes at a cost of incrased test run times.
|
||||
num_runs=4,
|
||||
)
|
||||
28
tests/integration/test_multi_agent.py
Normal file
28
tests/integration/test_multi_agent.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.evaluation import AgentEvaluator
|
||||
|
||||
|
||||
def test_eval_agent():
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.trip_planner_agent",
|
||||
eval_dataset_file_path_or_dir=(
|
||||
"tests/integration/fixture/trip_planner_agent/trip_inquiry.test.json"
|
||||
),
|
||||
initial_session_file=(
|
||||
"tests/integration/fixture/trip_planner_agent/initial.session.json"
|
||||
),
|
||||
num_runs=4,
|
||||
)
|
||||
42
tests/integration/test_multi_turn.py
Normal file
42
tests/integration/test_multi_turn.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.evaluation import AgentEvaluator
|
||||
|
||||
|
||||
def test_simple_multi_turn_conversation():
|
||||
"""Test a simple multi-turn conversation."""
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.home_automation_agent",
|
||||
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/simple_multi_turn_conversation.test.json",
|
||||
num_runs=4,
|
||||
)
|
||||
|
||||
|
||||
def test_dependent_tool_calls():
|
||||
"""Test subsequent tool calls that are dependent on previous tool calls."""
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.home_automation_agent",
|
||||
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/dependent_tool_calls.test.json",
|
||||
num_runs=4,
|
||||
)
|
||||
|
||||
|
||||
def test_memorizing_past_events():
|
||||
"""Test memorizing past events."""
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.home_automation_agent",
|
||||
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/test_files/memorizing_past_events/eval_data.test.json",
|
||||
num_runs=4,
|
||||
)
|
||||
23
tests/integration/test_single_agent.py
Normal file
23
tests/integration/test_single_agent.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.evaluation import AgentEvaluator
|
||||
|
||||
|
||||
def test_eval_agent():
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.home_automation_agent",
|
||||
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
|
||||
num_runs=4,
|
||||
)
|
||||
26
tests/integration/test_sub_agent.py
Normal file
26
tests/integration/test_sub_agent.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.evaluation import AgentEvaluator
|
||||
|
||||
|
||||
def test_eval_agent():
|
||||
"""Test hotel sub agent in a multi-agent system."""
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.trip_planner_agent",
|
||||
eval_dataset_file_path_or_dir="tests/integration/fixture/trip_planner_agent/test_files/trip_inquiry_sub_agent.test.json",
|
||||
initial_session_file="tests/integration/fixture/trip_planner_agent/test_files/initial.session.json",
|
||||
agent_name="identify_agent",
|
||||
num_runs=4,
|
||||
)
|
||||
177
tests/integration/test_system_instruction.py
Normal file
177
tests/integration/test_system_instruction.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip until fixed.
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
from google.adk.agents import InvocationContext
|
||||
from google.adk.sessions import Session
|
||||
from google.genai import types
|
||||
|
||||
from .fixture import context_variable_agent
|
||||
from .utils import TestRunner
|
||||
|
||||
nl_planner_si = """
|
||||
You are an intelligent tool use agent built upon the Gemini large language model. When answering the question, try to leverage the available tools to gather the information instead of your memorized knowledge.
|
||||
|
||||
Follow this process when answering the question: (1) first come up with a plan in natural language text format; (2) Then use tools to execute the plan and provide reasoning between tool code snippets to make a summary of current state and next step. Tool code snippets and reasoning should be interleaved with each other. (3) In the end, return one final answer.
|
||||
|
||||
Follow this format when answering the question: (1) The planning part should be under /*PLANNING*/. (2) The tool code snippets should be under /*ACTION*/, and the reasoning parts should be under /*REASONING*/. (3) The final answer part should be under /*FINAL_ANSWER*/.
|
||||
|
||||
|
||||
Below are the requirements for the planning:
|
||||
The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take.
|
||||
If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be be under /*REPLANNING*/. Then use tools to follow the new plan.
|
||||
|
||||
Below are the requirements for the reasoning:
|
||||
The reasoning makes a summary of the current trajectory based on the user query and tool outputs. Based on the tool outputs and plan, the reasoning also comes up with instructions to the next steps, making the trajectory closer to the final answer.
|
||||
|
||||
|
||||
|
||||
Below are the requirements for the final answer:
|
||||
The final answer should be precise and follow query formatting requirements. Some queries may not be answerable with the available tools and information. In those cases, inform the user why you cannot process their query and ask for more information.
|
||||
|
||||
|
||||
|
||||
Below are the requirements for the tool code:
|
||||
|
||||
**Custom Tools:** The available tools are described in the context and can be directly used.
|
||||
- Code must be valid self-contained Python snippets with no imports and no references to tools or Python libraries that are not in the context.
|
||||
- You cannot use any parameters or fields that are not explicitly defined in the APIs in the context.
|
||||
- Use "print" to output execution results for the next step or final answer that you need for responding to the user. Never generate ```tool_outputs yourself.
|
||||
- The code snippets should be readable, efficient, and directly relevant to the user query and reasoning steps.
|
||||
- When using the tools, you should use the library name together with the function name, e.g., vertex_search.search().
|
||||
- If Python libraries are not provided in the context, NEVER write your own code other than the function calls using the provided tools.
|
||||
|
||||
|
||||
|
||||
VERY IMPORTANT instruction that you MUST follow in addition to the above instructions:
|
||||
|
||||
You should ask for clarification if you need more information to answer the question.
|
||||
You should prefer using the information available in the context instead of repeated tool use.
|
||||
|
||||
You should ONLY generate code snippets prefixed with "```tool_code" if you need to use the tools to answer the question.
|
||||
|
||||
If you are asked to write code by user specifically,
|
||||
- you should ALWAYS use "```python" to format the code.
|
||||
- you should NEVER put "tool_code" to format the code.
|
||||
- Good example:
|
||||
```python
|
||||
print('hello')
|
||||
```
|
||||
- Bad example:
|
||||
```tool_code
|
||||
print('hello')
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": context_variable_agent.agent.state_variable_echo_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_context_variable(agent_runner: TestRunner):
|
||||
session = Session(
|
||||
context={
|
||||
"customerId": "1234567890",
|
||||
"customerInt": 30,
|
||||
"customerFloat": 12.34,
|
||||
"customerJson": {"name": "John Doe", "age": 30, "count": 11.1},
|
||||
}
|
||||
)
|
||||
si = UnitFlow()._build_system_instruction(
|
||||
InvocationContext(
|
||||
invocation_id="1234567890", agent=agent_runner.agent, session=session
|
||||
)
|
||||
)
|
||||
|
||||
assert (
|
||||
"Use the echo_info tool to echo 1234567890, 30, 12.34, and {'name': 'John"
|
||||
" Doe', 'age': 30, 'count': 11.1}. Ask for it if you need to."
|
||||
in si
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{
|
||||
"agent": (
|
||||
context_variable_agent.agent.state_variable_with_complicated_format_agent
|
||||
)
|
||||
}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_context_variable_with_complicated_format(agent_runner: TestRunner):
|
||||
session = Session(
|
||||
context={"customerId": "1234567890", "customer_int": 30},
|
||||
artifacts={"fileName": [types.Part(text="test artifact")]},
|
||||
)
|
||||
si = _context_formatter.populate_context_and_artifact_variable_values(
|
||||
agent_runner.agent.instruction,
|
||||
session.get_state(),
|
||||
session.get_artifact_dict(),
|
||||
)
|
||||
|
||||
assert (
|
||||
si
|
||||
== "Use the echo_info tool to echo 1234567890, 30, { "
|
||||
" non-identifier-float}}, test artifact, {'key1': 'value1'} and"
|
||||
" {{'key2': 'value2'}}. Ask for it if you need to."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{
|
||||
"agent": (
|
||||
context_variable_agent.agent.state_variable_with_nl_planner_agent
|
||||
)
|
||||
}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_nl_planner(agent_runner: TestRunner):
|
||||
session = Session(context={"customerId": "1234567890"})
|
||||
si = UnitFlow()._build_system_instruction(
|
||||
InvocationContext(
|
||||
invocation_id="1234567890",
|
||||
agent=agent_runner.agent,
|
||||
session=session,
|
||||
)
|
||||
)
|
||||
|
||||
for line in nl_planner_si.splitlines():
|
||||
assert line in si
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{
|
||||
"agent": (
|
||||
context_variable_agent.agent.state_variable_with_function_instruction_agent
|
||||
)
|
||||
}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_function_instruction(agent_runner: TestRunner):
|
||||
session = Session(context={"customerId": "1234567890"})
|
||||
si = UnitFlow()._build_system_instruction(
|
||||
InvocationContext(
|
||||
invocation_id="1234567890", agent=agent_runner.agent, session=session
|
||||
)
|
||||
)
|
||||
|
||||
assert "This is the plain text sub agent instruction." in si
|
||||
287
tests/integration/test_tools.py
Normal file
287
tests/integration/test_tools.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip until fixed.
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
from .fixture import tool_agent
|
||||
from .utils import TestRunner
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.single_function_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_single_function_calls_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"simple_function",
|
||||
"test",
|
||||
"success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_multiple_function_calls_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"simple_function",
|
||||
"test",
|
||||
"success",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"no_param_function",
|
||||
None,
|
||||
"Called no param function successfully",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"no_output_function",
|
||||
"test",
|
||||
"",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"multiple_param_types_function",
|
||||
["test", 1, 2.34, True],
|
||||
"success",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"return_list_str_function",
|
||||
"test",
|
||||
"success",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"list_str_param_function",
|
||||
["test", "test2", "test3", "test4"],
|
||||
"success",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Currently failing with 400 on MLDev.")
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_complex_function_calls_success(agent_runner: TestRunner):
|
||||
param1 = {"name": "Test", "count": 3}
|
||||
param2 = [
|
||||
{"name": "Function", "count": 2},
|
||||
{"name": "Retrieval", "count": 1},
|
||||
]
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"complex_function_list_dict",
|
||||
[param1, param2],
|
||||
"test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_repetive_call_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"repetive_call_1",
|
||||
"test",
|
||||
"test_repetive",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_function_calls_fail(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"throw_error_function",
|
||||
"test",
|
||||
None,
|
||||
ValueError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_agent_tools_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"no_schema_agent",
|
||||
"Hi",
|
||||
"Hi",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"schema_agent",
|
||||
"Agent_tools",
|
||||
"Agent_tools_success",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner, "no_input_schema_agent", "Tools", "Tools_success"
|
||||
)
|
||||
_call_function_and_assert(agent_runner, "no_output_schema_agent", "Hi", "Hi")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_files_retrieval_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"test_case_retrieval",
|
||||
"What is the testing strategy of agent 2.0?",
|
||||
"test",
|
||||
)
|
||||
# For non relevant query, the agent should still be running fine, just return
|
||||
# response might be different for different calls, so we don't compare the
|
||||
# response here.
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"test_case_retrieval",
|
||||
"What is the whether in bay area?",
|
||||
"",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_rag_retrieval_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"valid_rag_retrieval",
|
||||
"What is the testing strategy of agent 2.0?",
|
||||
"test",
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"valid_rag_retrieval",
|
||||
"What is the whether in bay area?",
|
||||
"No",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_rag_retrieval_fail(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"invalid_rag_retrieval",
|
||||
"What is the testing strategy of agent 2.0?",
|
||||
None,
|
||||
ValueError,
|
||||
)
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"non_exist_rag_retrieval",
|
||||
"What is the whether in bay area?",
|
||||
None,
|
||||
ValueError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_langchain_tool_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"terminal",
|
||||
"Run the following shell command 'echo test!'",
|
||||
"test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_runner",
|
||||
[{"agent": tool_agent.agent.root_agent}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_crewai_tool_success(agent_runner: TestRunner):
|
||||
_call_function_and_assert(
|
||||
agent_runner,
|
||||
"direcotry_read_tool",
|
||||
"Find all the file paths",
|
||||
"file",
|
||||
)
|
||||
|
||||
|
||||
def _call_function_and_assert(
|
||||
agent_runner: TestRunner,
|
||||
function_name: str,
|
||||
params,
|
||||
expected=None,
|
||||
exception: Exception = None,
|
||||
):
|
||||
param_section = (
|
||||
" with params"
|
||||
f" {params if isinstance(params, str) else json.dumps(params)}"
|
||||
if params is not None
|
||||
else ""
|
||||
)
|
||||
query = f"Call {function_name}{param_section} and show me the result"
|
||||
if exception:
|
||||
_assert_raises(agent_runner, query, exception)
|
||||
return
|
||||
|
||||
_assert_function_output(agent_runner, query, expected)
|
||||
|
||||
|
||||
def _assert_raises(agent_runner: TestRunner, query: str, exception: Exception):
|
||||
with pytest.raises(exception):
|
||||
agent_runner.run(query)
|
||||
|
||||
|
||||
def _assert_function_output(agent_runner: TestRunner, query: str, expected):
|
||||
agent_runner.run(query)
|
||||
|
||||
# Retrieve the latest model response event
|
||||
model_response_event = agent_runner.get_events()[-1]
|
||||
|
||||
# Assert the response content
|
||||
assert model_response_event.content.role == "model"
|
||||
assert (
|
||||
expected.lower()
|
||||
in model_response_event.content.parts[0].text.strip().lower()
|
||||
)
|
||||
34
tests/integration/test_with_test_file.py
Normal file
34
tests/integration/test_with_test_file.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.evaluation import AgentEvaluator
|
||||
|
||||
|
||||
def test_with_single_test_file():
|
||||
"""Test the agent's basic ability via session file."""
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.home_automation_agent",
|
||||
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
|
||||
)
|
||||
|
||||
|
||||
def test_with_folder_of_test_files_long_running():
|
||||
"""Test the agent's basic ability via a folder of session files."""
|
||||
AgentEvaluator.evaluate(
|
||||
agent_module="tests.integration.fixture.home_automation_agent",
|
||||
eval_dataset_file_path_or_dir=(
|
||||
"tests/integration/fixture/home_automation_agent/test_files"
|
||||
),
|
||||
num_runs=4,
|
||||
)
|
||||
14
tests/integration/tools/__init__.py
Normal file
14
tests/integration/tools/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
16
tests/integration/utils/__init__.py
Normal file
16
tests/integration/utils/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .asserts import *
|
||||
from .test_runner import TestRunner
|
||||
75
tests/integration/utils/asserts.py
Normal file
75
tests/integration/utils/asserts.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
from .test_runner import TestRunner
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
agent_name: str
|
||||
expected_text: str
|
||||
|
||||
|
||||
def assert_current_agent_is(agent_name: str, *, agent_runner: TestRunner):
|
||||
assert agent_runner.get_current_agent_name() == agent_name
|
||||
|
||||
|
||||
def assert_agent_says(
|
||||
expected_text: str, *, agent_name: str, agent_runner: TestRunner
|
||||
):
|
||||
for event in reversed(agent_runner.get_events()):
|
||||
if event.author == agent_name and event.content.parts[0].text:
|
||||
assert event.content.parts[0].text.strip() == expected_text
|
||||
return
|
||||
|
||||
|
||||
def assert_agent_says_in_order(
|
||||
expected_conversation: list[Message], agent_runner: TestRunner
|
||||
):
|
||||
expected_conversation_idx = len(expected_conversation) - 1
|
||||
for event in reversed(agent_runner.get_events()):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
assert (
|
||||
event.author
|
||||
== expected_conversation[expected_conversation_idx]['agent_name']
|
||||
)
|
||||
assert (
|
||||
event.content.parts[0].text.strip()
|
||||
== expected_conversation[expected_conversation_idx]['expected_text']
|
||||
)
|
||||
expected_conversation_idx -= 1
|
||||
if expected_conversation_idx < 0:
|
||||
return
|
||||
|
||||
|
||||
def assert_agent_transfer_path(
|
||||
expected_path: list[str], *, agent_runner: TestRunner
|
||||
):
|
||||
events = agent_runner.get_events()
|
||||
idx_in_expected_path = len(expected_path) - 1
|
||||
# iterate events in reverse order
|
||||
for event in reversed(events):
|
||||
function_calls = event.get_function_calls()
|
||||
if (
|
||||
len(function_calls) == 1
|
||||
and function_calls[0].name == 'transfer_to_agent'
|
||||
):
|
||||
assert (
|
||||
function_calls[0].args['agent_name']
|
||||
== expected_path[idx_in_expected_path]
|
||||
)
|
||||
idx_in_expected_path -= 1
|
||||
if idx_in_expected_path < 0:
|
||||
return
|
||||
97
tests/integration/utils/test_runner.py
Normal file
97
tests/integration/utils/test_runner.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk import Runner
|
||||
from google.adk.artifacts import BaseArtifactService
|
||||
from google.adk.artifacts import InMemoryArtifactService
|
||||
from google.adk.events import Event
|
||||
from google.adk.sessions import BaseSessionService
|
||||
from google.adk.sessions import InMemorySessionService
|
||||
from google.adk.sessions import Session
|
||||
from google.genai import types
|
||||
|
||||
|
||||
class TestRunner:
|
||||
"""Agents runner for testing."""
|
||||
|
||||
app_name = "test_app"
|
||||
user_id = "test_user"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: Agent,
|
||||
artifact_service: BaseArtifactService = InMemoryArtifactService(),
|
||||
session_service: BaseSessionService = InMemorySessionService(),
|
||||
) -> None:
|
||||
self.agent = agent
|
||||
self.agent_client = Runner(
|
||||
app_name=self.app_name,
|
||||
agent=agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
self.session_service = session_service
|
||||
self.current_session_id = session_service.create_session(
|
||||
app_name=self.app_name, user_id=self.user_id
|
||||
).id
|
||||
|
||||
def new_session(self, session_id: Optional[str] = None) -> None:
|
||||
self.current_session_id = self.session_service.create_session(
|
||||
app_name=self.app_name, user_id=self.user_id, session_id=session_id
|
||||
).id
|
||||
|
||||
def run(self, prompt: str) -> list[Event]:
|
||||
current_session = self.session_service.get_session(
|
||||
app_name=self.app_name,
|
||||
user_id=self.user_id,
|
||||
session_id=self.current_session_id,
|
||||
)
|
||||
assert current_session is not None
|
||||
|
||||
return list(
|
||||
self.agent_client.run(
|
||||
user_id=current_session.user_id,
|
||||
session_id=current_session.id,
|
||||
new_message=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text=prompt)],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def get_current_session(self) -> Optional[Session]:
|
||||
return self.session_service.get_session(
|
||||
app_name=self.app_name,
|
||||
user_id=self.user_id,
|
||||
session_id=self.current_session_id,
|
||||
)
|
||||
|
||||
def get_events(self) -> list[Event]:
|
||||
return self.get_current_session().events
|
||||
|
||||
@classmethod
|
||||
def from_agent_name(cls, agent_name: str):
|
||||
agent_module_path = f"tests.integration.fixture.{agent_name}"
|
||||
agent_module = importlib.import_module(agent_module_path)
|
||||
agent: Agent = agent_module.agent.root_agent
|
||||
return cls(agent)
|
||||
|
||||
def get_current_agent_name(self) -> str:
|
||||
return self.agent_client._find_agent_to_run(
|
||||
self.get_current_session(), self.agent
|
||||
).name
|
||||
13
tests/unittests/cli/__init__.py
Normal file
13
tests/unittests/cli/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
13
tests/unittests/cli/utils/__init__.py
Normal file
13
tests/unittests/cli/utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
434
tests/unittests/cli/utils/test_evals.py
Normal file
434
tests/unittests/cli/utils/test_evals.py
Normal file
@@ -0,0 +1,434 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for utilities in eval."""
|
||||
|
||||
|
||||
from google.adk.cli.utils.evals import convert_session_to_eval_format
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def build_event(author: str, parts_content: list[dict]) -> Event:
|
||||
"""Builds an Event object with specified parts."""
|
||||
parts = []
|
||||
for p_data in parts_content:
|
||||
part_args = {}
|
||||
if "text" in p_data:
|
||||
part_args["text"] = p_data["text"]
|
||||
if "func_name" in p_data:
|
||||
part_args["function_call"] = types.FunctionCall(
|
||||
name=p_data.get("func_name"), args=p_data.get("func_args")
|
||||
)
|
||||
# Add other part types here if needed for future tests
|
||||
parts.append(types.Part(**part_args))
|
||||
return Event(author=author, content=types.Content(parts=parts))
|
||||
|
||||
|
||||
def test_convert_empty_session():
|
||||
"""Test conversion function with empty events list in Session."""
|
||||
# Pydantic models require mandatory fields for instantiation
|
||||
session_empty_events = Session(
|
||||
id="s1", app_name="app", user_id="u1", events=[]
|
||||
)
|
||||
assert not convert_session_to_eval_format(session_empty_events)
|
||||
|
||||
|
||||
def test_convert_none_session():
|
||||
"""Test conversion function with None Session."""
|
||||
assert not convert_session_to_eval_format(None)
|
||||
|
||||
|
||||
def test_convert_session_skips_initial_non_user_events():
|
||||
"""Test conversion function with only user events."""
|
||||
events = [
|
||||
build_event("model", [{"text": "Hello"}]),
|
||||
build_event("user", [{"text": "How are you?"}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [
|
||||
{
|
||||
"query": "How are you?",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "",
|
||||
},
|
||||
]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_single_turn_text_only():
|
||||
"""Test a single user query followed by a single agent text response."""
|
||||
events = [
|
||||
build_event("user", [{"text": "What is the time?"}]),
|
||||
build_event("root_agent", [{"text": "It is 3 PM."}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [{
|
||||
"query": "What is the time?",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "It is 3 PM.",
|
||||
}]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_single_turn_tool_only():
|
||||
"""Test a single user query followed by a single agent tool call."""
|
||||
events = [
|
||||
build_event("user", [{"text": "Get weather for Seattle"}]),
|
||||
build_event(
|
||||
"root_agent",
|
||||
[{"func_name": "get_weather", "func_args": {"city": "Seattle"}}],
|
||||
),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [{
|
||||
"query": "Get weather for Seattle",
|
||||
"expected_tool_use": [
|
||||
{"tool_name": "get_weather", "tool_input": {"city": "Seattle"}}
|
||||
],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "",
|
||||
}]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_single_turn_multiple_tools_and_texts():
|
||||
"""Test a turn with multiple agent responses (tools and text)."""
|
||||
events = [
|
||||
build_event("user", [{"text": "Do task A then task B"}]),
|
||||
build_event(
|
||||
"root_agent", [{"text": "Okay, starting task A."}]
|
||||
), # Intermediate Text 1
|
||||
build_event(
|
||||
"root_agent", [{"func_name": "task_A", "func_args": {"param": 1}}]
|
||||
), # Tool 1
|
||||
build_event(
|
||||
"root_agent", [{"text": "Task A done. Now starting task B."}]
|
||||
), # Intermediate Text 2
|
||||
build_event(
|
||||
"another_agent", [{"func_name": "task_B", "func_args": {}}]
|
||||
), # Tool 2
|
||||
build_event(
|
||||
"root_agent", [{"text": "All tasks completed."}]
|
||||
), # Final Text (Reference)
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [{
|
||||
"query": "Do task A then task B",
|
||||
"expected_tool_use": [
|
||||
{"tool_name": "task_A", "tool_input": {"param": 1}},
|
||||
{"tool_name": "task_B", "tool_input": {}},
|
||||
],
|
||||
"expected_intermediate_agent_responses": [
|
||||
{"author": "root_agent", "text": "Okay, starting task A."},
|
||||
{
|
||||
"author": "root_agent",
|
||||
"text": "Task A done. Now starting task B.",
|
||||
},
|
||||
],
|
||||
"reference": "All tasks completed.",
|
||||
}]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_multi_turn_session():
|
||||
"""Test a session with multiple user/agent turns."""
|
||||
events = [
|
||||
build_event("user", [{"text": "Query 1"}]),
|
||||
build_event("agent", [{"text": "Response 1"}]),
|
||||
build_event("user", [{"text": "Query 2"}]),
|
||||
build_event("agent", [{"func_name": "tool_X", "func_args": {}}]),
|
||||
build_event("agent", [{"text": "Response 2"}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [
|
||||
{ # Turn 1
|
||||
"query": "Query 1",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "Response 1",
|
||||
},
|
||||
{ # Turn 2
|
||||
"query": "Query 2",
|
||||
"expected_tool_use": [{"tool_name": "tool_X", "tool_input": {}}],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "Response 2",
|
||||
},
|
||||
]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_agent_event_multiple_parts():
|
||||
"""Test an agent event with both text and tool call parts."""
|
||||
events = [
|
||||
build_event("user", [{"text": "Do something complex"}]),
|
||||
# Build event with multiple dicts in parts_content list
|
||||
build_event(
|
||||
"agent",
|
||||
[
|
||||
{"text": "Okay, doing it."},
|
||||
{"func_name": "complex_tool", "func_args": {"value": True}},
|
||||
],
|
||||
),
|
||||
build_event("agent", [{"text": "Finished."}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [{
|
||||
"query": "Do something complex",
|
||||
"expected_tool_use": [
|
||||
{"tool_name": "complex_tool", "tool_input": {"value": True}}
|
||||
],
|
||||
"expected_intermediate_agent_responses": [{
|
||||
"author": "agent",
|
||||
"text": "Okay, doing it.",
|
||||
}], # Text from first part of agent event
|
||||
"reference": "Finished.", # Text from second agent event
|
||||
}]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_handles_missing_content_or_parts():
|
||||
"""Test that events missing content or parts are skipped gracefully."""
|
||||
events = [
|
||||
build_event("user", [{"text": "Query 1"}]),
|
||||
Event(author="agent", content=None), # Agent event missing content
|
||||
build_event("agent", [{"text": "Response 1"}]),
|
||||
Event(author="user", content=None), # User event missing content
|
||||
build_event("user", [{"text": "Query 2"}]),
|
||||
Event(
|
||||
author="agent", content=types.Content(parts=[])
|
||||
), # Agent event with empty parts list
|
||||
build_event("agent", [{"text": "Response 2"}]),
|
||||
# User event with content but no parts (or None parts)
|
||||
Event(author="user", content=types.Content(parts=None)),
|
||||
build_event("user", [{"text": "Query 3"}]),
|
||||
build_event("agent", [{"text": "Response 3"}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [
|
||||
{ # Turn 1 (from Query 1)
|
||||
"query": "Query 1",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "Response 1",
|
||||
},
|
||||
{ # Turn 2 (from Query 2 - user event with None content was skipped)
|
||||
"query": "Query 2",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "Response 2",
|
||||
},
|
||||
{ # Turn 3 (from Query 3 - user event with None parts was skipped)
|
||||
"query": "Query 3",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "Response 3",
|
||||
},
|
||||
]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_handles_missing_tool_name_or_args():
|
||||
"""Test tool calls with missing name or args."""
|
||||
events = [
|
||||
build_event("user", [{"text": "Call tools"}]),
|
||||
# Event where FunctionCall has name=None
|
||||
Event(
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(name=None, args={"a": 1})
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
# Event where FunctionCall has args=None
|
||||
Event(
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(name="tool_B", args=None)
|
||||
)
|
||||
]
|
||||
),
|
||||
),
|
||||
# Event where FunctionCall part exists but FunctionCall object is None
|
||||
# (should skip)
|
||||
Event(
|
||||
author="agent",
|
||||
content=types.Content(
|
||||
parts=[types.Part(function_call=None, text="some text")]
|
||||
),
|
||||
),
|
||||
build_event("agent", [{"text": "Done"}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [{
|
||||
"query": "Call tools",
|
||||
"expected_tool_use": [
|
||||
{"tool_name": "", "tool_input": {"a": 1}}, # Defaults name to ""
|
||||
{"tool_name": "tool_B", "tool_input": {}}, # Defaults args to {}
|
||||
],
|
||||
"expected_intermediate_agent_responses": [{
|
||||
"author": "agent",
|
||||
"text": "some text",
|
||||
}], # Text part from the event where function_call was None
|
||||
"reference": "Done",
|
||||
}]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_handles_missing_user_query_text():
|
||||
"""Test user event where the first part has no text."""
|
||||
events = [
|
||||
# Event where user part has text=None
|
||||
Event(
|
||||
author="user", content=types.Content(parts=[types.Part(text=None)])
|
||||
),
|
||||
build_event("agent", [{"text": "Response 1"}]),
|
||||
# Event where user part has text=""
|
||||
build_event("user", [{"text": ""}]),
|
||||
build_event("agent", [{"text": "Response 2"}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [
|
||||
{
|
||||
"query": "", # Defaults to "" if text is None
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "Response 1",
|
||||
},
|
||||
{
|
||||
"query": "", # Defaults to "" if text is ""
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "Response 2",
|
||||
},
|
||||
]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_handles_empty_agent_text():
|
||||
"""Test agent responses with empty string text."""
|
||||
events = [
|
||||
build_event("user", [{"text": "Query"}]),
|
||||
build_event("agent", [{"text": "Okay"}]),
|
||||
build_event("agent", [{"text": ""}]), # Empty text
|
||||
build_event("agent", [{"text": "Done"}]),
|
||||
]
|
||||
session = Session(id="s1", app_name="app", user_id="u1", events=events)
|
||||
expected = [{
|
||||
"query": "Query",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [
|
||||
{"author": "agent", "text": "Okay"},
|
||||
],
|
||||
"reference": "Done",
|
||||
}]
|
||||
assert convert_session_to_eval_format(session) == expected
|
||||
|
||||
|
||||
def test_convert_complex_sample_session():
|
||||
"""Test using the complex sample session provided earlier."""
|
||||
events = [
|
||||
build_event("user", [{"text": "What can you do?"}]),
|
||||
build_event(
|
||||
"root_agent",
|
||||
[{"text": "I can roll dice and check if numbers are prime. \n"}],
|
||||
),
|
||||
build_event(
|
||||
"user",
|
||||
[{
|
||||
"text": (
|
||||
"Roll a 8 sided dice and then check if 90 is a prime number"
|
||||
" or not."
|
||||
)
|
||||
}],
|
||||
),
|
||||
build_event(
|
||||
"root_agent",
|
||||
[{
|
||||
"func_name": "transfer_to_agent",
|
||||
"func_args": {"agent_name": "roll_agent"},
|
||||
}],
|
||||
),
|
||||
# Skipping FunctionResponse events as they don't have text/functionCall
|
||||
# parts used by converter
|
||||
build_event(
|
||||
"roll_agent", [{"func_name": "roll_die", "func_args": {"sides": 8}}]
|
||||
),
|
||||
# Skipping FunctionResponse
|
||||
build_event(
|
||||
"roll_agent",
|
||||
[
|
||||
{"text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n"},
|
||||
{
|
||||
"func_name": "transfer_to_agent",
|
||||
"func_args": {"agent_name": "prime_agent"},
|
||||
},
|
||||
],
|
||||
),
|
||||
# Skipping FunctionResponse
|
||||
build_event(
|
||||
"prime_agent",
|
||||
[{"func_name": "check_prime", "func_args": {"nums": [90]}}],
|
||||
),
|
||||
# Skipping FunctionResponse
|
||||
build_event("prime_agent", [{"text": "90 is not a prime number. \n"}]),
|
||||
]
|
||||
session = Session(
|
||||
id="some_id",
|
||||
app_name="hello_world_ma",
|
||||
user_id="user",
|
||||
events=events,
|
||||
)
|
||||
expected = [
|
||||
{
|
||||
"query": "What can you do?",
|
||||
"expected_tool_use": [],
|
||||
"expected_intermediate_agent_responses": [],
|
||||
"reference": "I can roll dice and check if numbers are prime. \n",
|
||||
},
|
||||
{
|
||||
"query": (
|
||||
"Roll a 8 sided dice and then check if 90 is a prime number or"
|
||||
" not."
|
||||
),
|
||||
"expected_tool_use": [
|
||||
{
|
||||
"tool_name": "transfer_to_agent",
|
||||
"tool_input": {"agent_name": "roll_agent"},
|
||||
},
|
||||
{"tool_name": "roll_die", "tool_input": {"sides": 8}},
|
||||
{
|
||||
"tool_name": "transfer_to_agent",
|
||||
"tool_input": {"agent_name": "prime_agent"},
|
||||
}, # From combined event
|
||||
{"tool_name": "check_prime", "tool_input": {"nums": [90]}},
|
||||
],
|
||||
"expected_intermediate_agent_responses": [{
|
||||
"author": "roll_agent",
|
||||
"text": "I rolled a 2. Now, I'll check if 90 is prime. \n\n",
|
||||
}], # Text from combined event
|
||||
"reference": "90 is not a prime number. \n",
|
||||
},
|
||||
]
|
||||
|
||||
actual = convert_session_to_eval_format(session)
|
||||
assert actual == expected
|
||||
13
tests/unittests/evaluation/__init__.py
Normal file
13
tests/unittests/evaluation/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
259
tests/unittests/evaluation/test_response_evaluator.py
Normal file
259
tests/unittests/evaluation/test_response_evaluator.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the Response Evaluator."""
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.evaluation.response_evaluator import ResponseEvaluator
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from vertexai.preview.evaluation import MetricPromptTemplateExamples
|
||||
|
||||
# Mock object for the result normally returned by _perform_eval
|
||||
MOCK_EVAL_RESULT = MagicMock()
|
||||
MOCK_EVAL_RESULT.summary_metrics = {"mock_metric": 0.75, "another_mock": 3.5}
|
||||
# Add a metrics_table for testing _print_results interaction
|
||||
MOCK_EVAL_RESULT.metrics_table = pd.DataFrame({
|
||||
"prompt": ["mock_query1"],
|
||||
"response": ["mock_resp1"],
|
||||
"mock_metric": [0.75],
|
||||
})
|
||||
|
||||
SAMPLE_TURN_1_ALL_KEYS = {
|
||||
"query": "query1",
|
||||
"response": "response1",
|
||||
"actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
|
||||
"expected_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
|
||||
"reference": "reference1",
|
||||
}
|
||||
SAMPLE_TURN_2_MISSING_REF = {
|
||||
"query": "query2",
|
||||
"response": "response2",
|
||||
"actual_tool_use": [],
|
||||
"expected_tool_use": [],
|
||||
# "reference": "reference2" # Missing
|
||||
}
|
||||
SAMPLE_TURN_3_MISSING_EXP_TOOLS = {
|
||||
"query": "query3",
|
||||
"response": "response3",
|
||||
"actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
|
||||
# "expected_tool_use": [], # Missing
|
||||
"reference": "reference3",
|
||||
}
|
||||
SAMPLE_TURN_4_MINIMAL = {
|
||||
"query": "query4",
|
||||
"response": "response4",
|
||||
# Minimal keys, others missing
|
||||
}
|
||||
|
||||
|
||||
@patch(
|
||||
"google.adk.evaluation.response_evaluator.ResponseEvaluator._perform_eval"
|
||||
)
|
||||
class TestResponseEvaluator:
|
||||
"""A class to help organize "patch" that are applicabple to all tests."""
|
||||
|
||||
def test_evaluate_none_dataset_raises_value_error(self, mock_perform_eval):
|
||||
"""Test evaluate function raises ValueError for an empty list."""
|
||||
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
|
||||
ResponseEvaluator.evaluate(None, ["response_evaluation_score"])
|
||||
mock_perform_eval.assert_not_called() # Ensure _perform_eval was not called
|
||||
|
||||
def test_evaluate_empty_dataset_raises_value_error(self, mock_perform_eval):
|
||||
"""Test evaluate function raises ValueError for an empty list."""
|
||||
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
|
||||
ResponseEvaluator.evaluate([], ["response_evaluation_score"])
|
||||
mock_perform_eval.assert_not_called() # Ensure _perform_eval was not called
|
||||
|
||||
def test_evaluate_determines_metrics_correctly_for_perform_eval(
|
||||
self, mock_perform_eval
|
||||
):
|
||||
"""Test that the correct metrics list is passed to _perform_eval based on criteria/keys."""
|
||||
mock_perform_eval.return_value = MOCK_EVAL_RESULT
|
||||
|
||||
# Test case 1: Only Coherence
|
||||
raw_data_1 = [[SAMPLE_TURN_1_ALL_KEYS]]
|
||||
criteria_1 = ["response_evaluation_score"]
|
||||
ResponseEvaluator.evaluate(raw_data_1, criteria_1)
|
||||
_, kwargs = mock_perform_eval.call_args
|
||||
assert kwargs["metrics"] == [
|
||||
MetricPromptTemplateExamples.Pointwise.COHERENCE
|
||||
]
|
||||
mock_perform_eval.reset_mock() # Reset mock for next call
|
||||
|
||||
# Test case 2: Only Rouge
|
||||
raw_data_2 = [[SAMPLE_TURN_1_ALL_KEYS]]
|
||||
criteria_2 = ["response_match_score"]
|
||||
ResponseEvaluator.evaluate(raw_data_2, criteria_2)
|
||||
_, kwargs = mock_perform_eval.call_args
|
||||
assert kwargs["metrics"] == ["rouge_1"]
|
||||
mock_perform_eval.reset_mock()
|
||||
|
||||
# Test case 3: No metrics if keys missing in first turn
|
||||
raw_data_3 = [[SAMPLE_TURN_4_MINIMAL, SAMPLE_TURN_1_ALL_KEYS]]
|
||||
criteria_3 = ["response_evaluation_score", "response_match_score"]
|
||||
ResponseEvaluator.evaluate(raw_data_3, criteria_3)
|
||||
_, kwargs = mock_perform_eval.call_args
|
||||
assert kwargs["metrics"] == []
|
||||
mock_perform_eval.reset_mock()
|
||||
|
||||
# Test case 4: No metrics if criteria empty
|
||||
raw_data_4 = [[SAMPLE_TURN_1_ALL_KEYS]]
|
||||
criteria_4 = []
|
||||
ResponseEvaluator.evaluate(raw_data_4, criteria_4)
|
||||
_, kwargs = mock_perform_eval.call_args
|
||||
assert kwargs["metrics"] == []
|
||||
mock_perform_eval.reset_mock()
|
||||
|
||||
def test_evaluate_calls_perform_eval_correctly_all_metrics(
|
||||
self, mock_perform_eval
|
||||
):
|
||||
"""Test evaluate function calls _perform_eval with expected args when all criteria/keys are present."""
|
||||
# Arrange
|
||||
mock_perform_eval.return_value = (
|
||||
MOCK_EVAL_RESULT # Configure the mock return value
|
||||
)
|
||||
|
||||
raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
|
||||
criteria = ["response_evaluation_score", "response_match_score"]
|
||||
|
||||
# Act
|
||||
summary = ResponseEvaluator.evaluate(raw_data, criteria)
|
||||
|
||||
# Assert
|
||||
# 1. Check metrics determined by _get_metrics (passed to _perform_eval)
|
||||
expected_metrics_list = [
|
||||
MetricPromptTemplateExamples.Pointwise.COHERENCE,
|
||||
"rouge_1",
|
||||
]
|
||||
# 2. Check DataFrame prepared (passed to _perform_eval)
|
||||
expected_df_data = [{
|
||||
"prompt": "query1",
|
||||
"response": "response1",
|
||||
"actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
|
||||
"reference_trajectory": [{"tool_name": "tool_a", "tool_input": {}}],
|
||||
"reference": "reference1",
|
||||
}]
|
||||
expected_df = pd.DataFrame(expected_df_data)
|
||||
|
||||
# Assert _perform_eval was called once
|
||||
mock_perform_eval.assert_called_once()
|
||||
# Get the arguments passed to the mocked _perform_eval
|
||||
_, kwargs = mock_perform_eval.call_args
|
||||
# Check the 'dataset' keyword argument
|
||||
pd.testing.assert_frame_equal(kwargs["dataset"], expected_df)
|
||||
# Check the 'metrics' keyword argument
|
||||
assert kwargs["metrics"] == expected_metrics_list
|
||||
|
||||
# 3. Check the correct summary metrics are returned
|
||||
# (from mock_perform_eval's return value)
|
||||
assert summary == MOCK_EVAL_RESULT.summary_metrics
|
||||
|
||||
def test_evaluate_prepares_dataframe_correctly_for_perform_eval(
|
||||
self, mock_perform_eval
|
||||
):
|
||||
"""Test that the DataFrame is correctly flattened and renamed before passing to _perform_eval."""
|
||||
mock_perform_eval.return_value = MOCK_EVAL_RESULT
|
||||
|
||||
raw_data = [
|
||||
[SAMPLE_TURN_1_ALL_KEYS], # Conversation 1
|
||||
[
|
||||
SAMPLE_TURN_2_MISSING_REF,
|
||||
SAMPLE_TURN_3_MISSING_EXP_TOOLS,
|
||||
], # Conversation 2
|
||||
]
|
||||
criteria = [
|
||||
"response_match_score"
|
||||
] # Doesn't affect the DataFrame structure
|
||||
|
||||
ResponseEvaluator.evaluate(raw_data, criteria)
|
||||
|
||||
# Expected flattened and renamed data
|
||||
expected_df_data = [
|
||||
# Turn 1 (from SAMPLE_TURN_1_ALL_KEYS)
|
||||
{
|
||||
"prompt": "query1",
|
||||
"response": "response1",
|
||||
"actual_tool_use": [{"tool_name": "tool_a", "tool_input": {}}],
|
||||
"reference_trajectory": [{"tool_name": "tool_a", "tool_input": {}}],
|
||||
"reference": "reference1",
|
||||
},
|
||||
# Turn 2 (from SAMPLE_TURN_2_MISSING_REF)
|
||||
{
|
||||
"prompt": "query2",
|
||||
"response": "response2",
|
||||
"actual_tool_use": [],
|
||||
"reference_trajectory": [],
|
||||
# "reference": None # Missing key results in NaN in DataFrame
|
||||
# usually
|
||||
},
|
||||
# Turn 3 (from SAMPLE_TURN_3_MISSING_EXP_TOOLS)
|
||||
{
|
||||
"prompt": "query3",
|
||||
"response": "response3",
|
||||
"actual_tool_use": [{"tool_name": "tool_b", "tool_input": {}}],
|
||||
# "reference_trajectory": None, # Missing key results in NaN
|
||||
"reference": "reference3",
|
||||
},
|
||||
]
|
||||
# Need to be careful with missing keys -> NaN when creating DataFrame
|
||||
# Pandas handles this automatically when creating from list of dicts
|
||||
expected_df = pd.DataFrame(expected_df_data)
|
||||
|
||||
mock_perform_eval.assert_called_once()
|
||||
_, kwargs = mock_perform_eval.call_args
|
||||
# Compare the DataFrame passed to the mock
|
||||
pd.testing.assert_frame_equal(kwargs["dataset"], expected_df)
|
||||
|
||||
@patch(
|
||||
"google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
|
||||
) # Mock the private print method
|
||||
def test_evaluate_print_detailed_results(
|
||||
self, mock_print_results, mock_perform_eval
|
||||
):
|
||||
"""Test _print_results function is called when print_detailed_results=True."""
|
||||
mock_perform_eval.return_value = (
|
||||
MOCK_EVAL_RESULT # Ensure _perform_eval returns our mock result
|
||||
)
|
||||
|
||||
raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
|
||||
criteria = ["response_match_score"]
|
||||
|
||||
ResponseEvaluator.evaluate(raw_data, criteria, print_detailed_results=True)
|
||||
|
||||
# Assert _perform_eval was called
|
||||
mock_perform_eval.assert_called_once()
|
||||
# Assert _print_results was called once with the result object
|
||||
# from _perform_eval
|
||||
mock_print_results.assert_called_once_with(MOCK_EVAL_RESULT)
|
||||
|
||||
@patch(
|
||||
"google.adk.evaluation.response_evaluator.ResponseEvaluator._print_results"
|
||||
)
|
||||
def test_evaluate_no_print_detailed_results(
|
||||
self, mock_print_results, mock_perform_eval
|
||||
):
|
||||
"""Test _print_results function is NOT called when print_detailed_results=False (default)."""
|
||||
mock_perform_eval.return_value = MOCK_EVAL_RESULT
|
||||
|
||||
raw_data = [[SAMPLE_TURN_1_ALL_KEYS]]
|
||||
criteria = ["response_match_score"]
|
||||
|
||||
ResponseEvaluator.evaluate(raw_data, criteria, print_detailed_results=False)
|
||||
|
||||
# Assert _perform_eval was called
|
||||
mock_perform_eval.assert_called_once()
|
||||
# Assert _print_results was NOT called
|
||||
mock_print_results.assert_not_called()
|
||||
271
tests/unittests/evaluation/test_trajectory_evaluator.py
Normal file
271
tests/unittests/evaluation/test_trajectory_evaluator.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Testings for the Trajectory Evaluator."""
|
||||
|
||||
import math
|
||||
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
|
||||
import pytest
|
||||
|
||||
# Define reusable tool call structures
|
||||
TOOL_ROLL_DICE_16 = {"tool_name": "roll_die", "tool_input": {"sides": 16}}
|
||||
TOOL_ROLL_DICE_6 = {"tool_name": "roll_die", "tool_input": {"sides": 6}}
|
||||
TOOL_GET_WEATHER = {
|
||||
"tool_name": "get_weather",
|
||||
"tool_input": {"location": "Paris"},
|
||||
}
|
||||
TOOL_GET_WEATHER_SF = {
|
||||
"tool_name": "get_weather",
|
||||
"tool_input": {"location": "SF"},
|
||||
}
|
||||
|
||||
# Sample data for turns
|
||||
TURN_MATCH = {
|
||||
"query": "Q1",
|
||||
"response": "R1",
|
||||
"actual_tool_use": [TOOL_ROLL_DICE_16],
|
||||
"expected_tool_use": [TOOL_ROLL_DICE_16],
|
||||
}
|
||||
TURN_MISMATCH_INPUT = {
|
||||
"query": "Q2",
|
||||
"response": "R2",
|
||||
"actual_tool_use": [TOOL_ROLL_DICE_6],
|
||||
"expected_tool_use": [TOOL_ROLL_DICE_16],
|
||||
}
|
||||
TURN_MISMATCH_NAME = {
|
||||
"query": "Q3",
|
||||
"response": "R3",
|
||||
"actual_tool_use": [TOOL_GET_WEATHER],
|
||||
"expected_tool_use": [TOOL_ROLL_DICE_16],
|
||||
}
|
||||
TURN_MATCH_MULTIPLE = {
|
||||
"query": "Q4",
|
||||
"response": "R4",
|
||||
"actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
|
||||
"expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
|
||||
}
|
||||
TURN_MISMATCH_ORDER = {
|
||||
"query": "Q5",
|
||||
"response": "R5",
|
||||
"actual_tool_use": [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER],
|
||||
"expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
|
||||
}
|
||||
TURN_MISMATCH_LENGTH_ACTUAL_LONGER = {
|
||||
"query": "Q6",
|
||||
"response": "R6",
|
||||
"actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
|
||||
"expected_tool_use": [TOOL_GET_WEATHER],
|
||||
}
|
||||
TURN_MISMATCH_LENGTH_EXPECTED_LONGER = {
|
||||
"query": "Q7",
|
||||
"response": "R7",
|
||||
"actual_tool_use": [TOOL_GET_WEATHER],
|
||||
"expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
|
||||
}
|
||||
TURN_MATCH_WITH_MOCK_OUTPUT = {
|
||||
"query": "Q8",
|
||||
"response": "R8",
|
||||
"actual_tool_use": [TOOL_GET_WEATHER_SF],
|
||||
"expected_tool_use": [
|
||||
{**TOOL_GET_WEATHER_SF, "mock_tool_output": "Sunny"}
|
||||
], # Add mock output to expected
|
||||
}
|
||||
TURN_MATCH_EMPTY_TOOLS = {
|
||||
"query": "Q9",
|
||||
"response": "R9",
|
||||
"actual_tool_use": [],
|
||||
"expected_tool_use": [],
|
||||
}
|
||||
TURN_MISMATCH_EMPTY_VS_NONEMPTY = {
|
||||
"query": "Q10",
|
||||
"response": "R10",
|
||||
"actual_tool_use": [],
|
||||
"expected_tool_use": [TOOL_GET_WEATHER],
|
||||
}
|
||||
|
||||
|
||||
def test_evaluate_none_dataset_raises_value_error():
|
||||
"""Tests evaluate function raises ValueError for an empty list."""
|
||||
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
|
||||
TrajectoryEvaluator.evaluate(None)
|
||||
|
||||
|
||||
def test_evaluate_empty_dataset_raises_value_error():
|
||||
"""Tests evaluate function raises ValueError for an empty list."""
|
||||
with pytest.raises(ValueError, match="The evaluation dataset is empty."):
|
||||
TrajectoryEvaluator.evaluate([])
|
||||
|
||||
|
||||
def test_evaluate_single_turn_match():
|
||||
"""Tests evaluate function with one conversation, one turn, perfect match."""
|
||||
eval_dataset = [[TURN_MATCH]]
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
|
||||
|
||||
|
||||
def test_evaluate_single_turn_mismatch():
|
||||
"""Tests evaluate function with one conversation, one turn, mismatch."""
|
||||
eval_dataset = [[TURN_MISMATCH_INPUT]]
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0
|
||||
|
||||
|
||||
def test_evaluate_multiple_turns_all_match():
|
||||
"""Tests evaluate function with one conversation, multiple turns, all match."""
|
||||
eval_dataset = [[TURN_MATCH, TURN_MATCH_MULTIPLE, TURN_MATCH_EMPTY_TOOLS]]
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
|
||||
|
||||
|
||||
def test_evaluate_multiple_turns_mixed():
|
||||
"""Tests evaluate function with one conversation, mixed match/mismatch turns."""
|
||||
eval_dataset = [
|
||||
[TURN_MATCH, TURN_MISMATCH_NAME, TURN_MATCH_MULTIPLE, TURN_MISMATCH_ORDER]
|
||||
]
|
||||
# Expected: (1.0 + 0.0 + 1.0 + 0.0) / 4 = 0.5
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5
|
||||
|
||||
|
||||
def test_evaluate_multiple_conversations_mixed():
|
||||
"""Tests evaluate function with multiple conversations, mixed turns."""
|
||||
eval_dataset = [
|
||||
[TURN_MATCH, TURN_MISMATCH_INPUT], # Conv 1: 1.0, 0.0 -> Avg 0.5
|
||||
[TURN_MATCH_MULTIPLE], # Conv 2: 1.0 -> Avg 1.0
|
||||
[
|
||||
TURN_MISMATCH_ORDER,
|
||||
TURN_MISMATCH_LENGTH_ACTUAL_LONGER,
|
||||
TURN_MATCH,
|
||||
], # Conv 3: 0.0, 0.0, 1.0 -> Avg 1/3
|
||||
]
|
||||
# Expected: (1.0 + 0.0 + 1.0 + 0.0 + 0.0 + 1.0) / 6 = 3.0 / 6 = 0.5
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5
|
||||
|
||||
|
||||
def test_evaluate_ignores_mock_tool_output_in_expected():
|
||||
"""Tests evaluate function correctly compares even if expected has mock_tool_output."""
|
||||
eval_dataset = [[TURN_MATCH_WITH_MOCK_OUTPUT]]
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
|
||||
|
||||
|
||||
def test_evaluate_match_empty_tool_lists():
|
||||
"""Tests evaluate function correctly matches empty tool lists."""
|
||||
eval_dataset = [[TURN_MATCH_EMPTY_TOOLS]]
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
|
||||
|
||||
|
||||
def test_evaluate_mismatch_empty_vs_nonempty():
|
||||
"""Tests evaluate function correctly mismatches empty vs non-empty tool lists."""
|
||||
eval_dataset = [[TURN_MISMATCH_EMPTY_VS_NONEMPTY]]
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0
|
||||
eval_dataset_rev = [[{
|
||||
**TURN_MISMATCH_EMPTY_VS_NONEMPTY, # Swap actual/expected
|
||||
"actual_tool_use": [TOOL_GET_WEATHER],
|
||||
"expected_tool_use": [],
|
||||
}]]
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset_rev) == 0.0
|
||||
|
||||
|
||||
def test_evaluate_dataset_with_empty_conversation():
|
||||
"""Tests evaluate function handles dataset containing an empty conversation list."""
|
||||
eval_dataset = [[TURN_MATCH], []] # One valid conversation, one empty
|
||||
# Should only evaluate the first conversation -> 1.0 / 1 turn = 1.0
|
||||
assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0
|
||||
|
||||
|
||||
def test_evaluate_dataset_only_empty_conversation():
|
||||
"""Tests evaluate function handles dataset with only an empty conversation."""
|
||||
eval_dataset = [[]]
|
||||
# No rows evaluated, mean of empty series is NaN
|
||||
# Depending on desired behavior, this could be 0.0 or NaN. The code returns
|
||||
# NaN.
|
||||
assert math.isnan(TrajectoryEvaluator.evaluate(eval_dataset))
|
||||
|
||||
|
||||
def test_evaluate_print_detailed_results(capsys):
|
||||
"""Tests evaluate function runs with print_detailed_results=True and prints something."""
|
||||
eval_dataset = [[TURN_MATCH, TURN_MISMATCH_INPUT]]
|
||||
TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "query" in captured.out # Check if the results table header is printed
|
||||
assert "R1" in captured.out # Check if some data is printed
|
||||
assert "Failures:" in captured.out # Check if failures header is printed
|
||||
assert "Q2" in captured.out # Check if the failing query is printed
|
||||
|
||||
|
||||
def test_evaluate_no_failures_print(capsys):
|
||||
"""Tests evaluate function does not print Failures section when all turns match."""
|
||||
eval_dataset = [[TURN_MATCH]]
|
||||
TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "query" in captured.out # Results table should still print
|
||||
assert "Failures:" not in captured.out # Failures section should NOT print
|
||||
|
||||
|
||||
def test_are_tools_equal_identical():
|
||||
"""Tests are_tools_equal function with identical lists."""
|
||||
list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
|
||||
list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
|
||||
assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)
|
||||
|
||||
|
||||
def test_are_tools_equal_empty():
|
||||
"""Tests are_tools_equal function with empty lists."""
|
||||
assert TrajectoryEvaluator.are_tools_equal([], [])
|
||||
|
||||
|
||||
def test_are_tools_equal_different_order():
|
||||
"""Tests are_tools_equal function with same tools, different order."""
|
||||
list_a = [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER]
|
||||
list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
|
||||
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
|
||||
|
||||
|
||||
def test_are_tools_equal_different_length():
|
||||
"""Tests are_tools_equal function with lists of different lengths."""
|
||||
list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
|
||||
list_b = [TOOL_GET_WEATHER]
|
||||
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
|
||||
|
||||
|
||||
def test_are_tools_equal_different_input_values():
|
||||
"""Tests are_tools_equal function with different input values."""
|
||||
list_a = [TOOL_ROLL_DICE_16]
|
||||
list_b = [TOOL_ROLL_DICE_6]
|
||||
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
|
||||
|
||||
|
||||
def test_are_tools_equal_different_tool_names():
|
||||
"""Tests are_tools_equal function with different tool names."""
|
||||
list_a = [TOOL_ROLL_DICE_16]
|
||||
list_b = [TOOL_GET_WEATHER]
|
||||
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
|
||||
|
||||
|
||||
def test_are_tools_equal_ignores_extra_keys():
|
||||
"""Tests are_tools_equal function ignores keys other than tool_name/tool_input."""
|
||||
list_a = [{
|
||||
"tool_name": "get_weather",
|
||||
"tool_input": {"location": "Paris"},
|
||||
"extra_key": "abc",
|
||||
}]
|
||||
list_b = [{
|
||||
"tool_name": "get_weather",
|
||||
"tool_input": {"location": "Paris"},
|
||||
"other_key": 123,
|
||||
}]
|
||||
assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)
|
||||
|
||||
|
||||
def test_are_tools_equal_one_empty_one_not():
|
||||
"""Tests are_tools_equal function with one empty list and one non-empty list."""
|
||||
list_a = []
|
||||
list_b = [TOOL_GET_WEATHER]
|
||||
assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
|
||||
@@ -225,3 +225,76 @@ def test_create_new_session_will_merge_states(service_type):
|
||||
assert session_2.state.get('user:key1') == 'value1'
|
||||
assert not session_2.state.get('key1')
|
||||
assert not session_2.state.get('temp:key')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_bytes(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
|
||||
assert session.events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
events = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
).events
|
||||
assert len(events) == 1
|
||||
assert events[0].content.parts[0] == types.Part.from_bytes(
|
||||
data=b'test_image_data', mime_type='image/png'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
def test_append_event_complete(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'user'
|
||||
|
||||
session = session_service.create_session(app_name=app_name, user_id=user_id)
|
||||
event = Event(
|
||||
invocation_id='invocation',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='test_text')]),
|
||||
turn_complete=True,
|
||||
partial=False,
|
||||
actions=EventActions(
|
||||
artifact_delta={
|
||||
'file': 0,
|
||||
},
|
||||
transfer_to_agent='agent',
|
||||
escalate=True,
|
||||
),
|
||||
long_running_tool_ids={'tool1'},
|
||||
error_code='error_code',
|
||||
error_message='error_message',
|
||||
interrupted=True,
|
||||
)
|
||||
session_service.append_event(session=session, event=event)
|
||||
|
||||
assert (
|
||||
session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
== session
|
||||
)
|
||||
|
||||
@@ -57,7 +57,7 @@ MOCK_EVENT_JSON = [
|
||||
{
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/test_engine/sessions/1/events/123'
|
||||
'reasoningEngines/123/sessions/1/events/123'
|
||||
),
|
||||
'invocationId': '123',
|
||||
'author': 'user',
|
||||
@@ -111,7 +111,7 @@ MOCK_SESSION = Session(
|
||||
|
||||
|
||||
SESSION_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)$'
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions$'
|
||||
SESSIONS_REGEX = r'^reasoningEngines/([^/]+)/sessions\?filter=user_id=([^/]+)$'
|
||||
EVENTS_REGEX = r'^reasoningEngines/([^/]+)/sessions/([^/]+)/events$'
|
||||
LRO_REGEX = r'^operations/([^/]+)$'
|
||||
|
||||
@@ -136,39 +136,52 @@ class MockApiClient:
|
||||
else:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
elif re.match(SESSIONS_REGEX, path):
|
||||
match = re.match(SESSIONS_REGEX, path)
|
||||
return {
|
||||
'sessions': self.session_dict.values(),
|
||||
'sessions': [
|
||||
session
|
||||
for session in self.session_dict.values()
|
||||
if session['userId'] == match.group(2)
|
||||
],
|
||||
}
|
||||
elif re.match(EVENTS_REGEX, path):
|
||||
match = re.match(EVENTS_REGEX, path)
|
||||
if match:
|
||||
return {'sessionEvents': self.event_dict[match.group(2)]}
|
||||
return {
|
||||
'sessionEvents': (
|
||||
self.event_dict[match.group(2)]
|
||||
if match.group(2) in self.event_dict
|
||||
else []
|
||||
)
|
||||
}
|
||||
elif re.match(LRO_REGEX, path):
|
||||
return {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/123'
|
||||
'reasoningEngines/123/sessions/4'
|
||||
),
|
||||
'done': True,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f'Unsupported path: {path}')
|
||||
elif http_method == 'POST':
|
||||
id = str(uuid.uuid4())
|
||||
self.session_dict[id] = {
|
||||
new_session_id = '4'
|
||||
self.session_dict[new_session_id] = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/'
|
||||
+ id
|
||||
+ new_session_id
|
||||
),
|
||||
'userId': request_dict['user_id'],
|
||||
'sessionState': request_dict.get('sessionState', {}),
|
||||
'sessionState': request_dict.get('session_state', {}),
|
||||
'updateTime': '2024-12-12T12:12:12.123456Z',
|
||||
}
|
||||
return {
|
||||
'name': (
|
||||
'projects/test_project/locations/test_location/'
|
||||
'reasoningEngines/test_engine/sessions/123'
|
||||
'reasoningEngines/123/sessions/'
|
||||
+ new_session_id
|
||||
+ '/operations/111'
|
||||
),
|
||||
'done': False,
|
||||
}
|
||||
@@ -223,24 +236,28 @@ def test_get_and_delete_session():
|
||||
)
|
||||
assert str(excinfo.value) == 'Session not found: 1'
|
||||
|
||||
def test_list_sessions():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == '1'
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
def test_create_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
session = session_service.create_session(
|
||||
app_name='123', user_id='user', state={'key': 'value'}
|
||||
)
|
||||
assert session.state == {'key': 'value'}
|
||||
assert session.app_name == '123'
|
||||
assert session.user_id == 'user'
|
||||
assert session.last_update_time is not None
|
||||
def test_list_sessions():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = session_service.list_sessions(app_name='123', user_id='user')
|
||||
assert len(sessions.sessions) == 2
|
||||
assert sessions.sessions[0].id == '1'
|
||||
assert sessions.sessions[1].id == '2'
|
||||
|
||||
session_id = session.id
|
||||
assert session == session_service.get_session(
|
||||
app_name='123', user_id='user', session_id=session_id
|
||||
)
|
||||
|
||||
def test_create_session():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
|
||||
state = {'key': 'value'}
|
||||
session = session_service.create_session(
|
||||
app_name='123', user_id='user', state=state
|
||||
)
|
||||
assert session.state == state
|
||||
assert session.app_name == '123'
|
||||
assert session.user_id == 'user'
|
||||
assert session.last_update_time is not None
|
||||
|
||||
session_id = session.id
|
||||
assert session == session_service.get_session(
|
||||
app_name='123', user_id='user', session_id=session_id
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ def calendar_api_spec():
|
||||
"methods": {
|
||||
"get": {
|
||||
"id": "calendar.calendars.get",
|
||||
"path": "calendars/{calendarId}",
|
||||
"flatPath": "calendars/{calendarId}",
|
||||
"httpMethod": "GET",
|
||||
"description": "Returns metadata for a calendar.",
|
||||
"parameters": {
|
||||
@@ -151,7 +151,7 @@ def calendar_api_spec():
|
||||
"methods": {
|
||||
"list": {
|
||||
"id": "calendar.events.list",
|
||||
"path": "calendars/{calendarId}/events",
|
||||
"flatPath": "calendars/{calendarId}/events",
|
||||
"httpMethod": "GET",
|
||||
"description": (
|
||||
"Returns events on the specified calendar."
|
||||
|
||||
Reference in New Issue
Block a user