adk-python/tests/unittests/testing_utils.py
Selcuk Gun 41b33d4a0a Move public_utils to utils in tests
Renamed conflicting utils.py as testing_utils.py

PiperOrigin-RevId: 761715808
2025-05-21 16:35:11 -07:00

314 lines
9.6 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
from typing import AsyncGenerator
from typing import Generator
from typing import Union
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.llm_agent import Agent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.events.event import Event
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.models.base_llm import BaseLlm
from google.adk.models.base_llm_connection import BaseLlmConnection
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.runners import InMemoryRunner as AfInMemoryRunner
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.genai import types
from google.genai.types import Part
from typing_extensions import override
class UserContent(types.Content):
def __init__(self, text_or_part: str):
parts = [
types.Part.from_text(text=text_or_part)
if isinstance(text_or_part, str)
else text_or_part
]
super().__init__(role='user', parts=parts)
class ModelContent(types.Content):
def __init__(self, parts: list[types.Part]):
super().__init__(role='model', parts=parts)
async def create_invocation_context(agent: Agent, user_content: str = ''):
invocation_id = 'test_id'
artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService()
memory_service = InMemoryMemoryService()
invocation_context = InvocationContext(
artifact_service=artifact_service,
session_service=session_service,
memory_service=memory_service,
invocation_id=invocation_id,
agent=agent,
session=await session_service.create_session(
app_name='test_app', user_id='test_user'
),
user_content=types.Content(
role='user', parts=[types.Part.from_text(text=user_content)]
),
run_config=RunConfig(),
)
if user_content:
append_user_content(
invocation_context, [types.Part.from_text(text=user_content)]
)
return invocation_context
def append_user_content(
invocation_context: InvocationContext, parts: list[types.Part]
) -> Event:
session = invocation_context.session
event = Event(
invocation_id=invocation_context.invocation_id,
author='user',
content=types.Content(role='user', parts=parts),
)
session.events.append(event)
return event
# Extracts the contents from the events and transform them into a list of
# (author, simplified_content) tuples.
def simplify_events(events: list[Event]) -> list[(str, types.Part)]:
return [(event.author, simplify_content(event.content)) for event in events]
# Simplifies the contents into a list of (author, simplified_content) tuples.
def simplify_contents(contents: list[types.Content]) -> list[(str, types.Part)]:
return [(content.role, simplify_content(content)) for content in contents]
# Simplifies the content so it's easier to assert.
# - If there is only one part, return part
# - If the only part is pure text, return stripped_text
# - If there are multiple parts, return parts
# - remove function_call_id if it exists
def simplify_content(
content: types.Content,
) -> Union[str, types.Part, list[types.Part]]:
for part in content.parts:
if part.function_call and part.function_call.id:
part.function_call.id = None
if part.function_response and part.function_response.id:
part.function_response.id = None
if len(content.parts) == 1:
if content.parts[0].text:
return content.parts[0].text.strip()
else:
return content.parts[0]
return content.parts
def get_user_content(message: types.ContentUnion) -> types.Content:
return message if isinstance(message, types.Content) else UserContent(message)
class TestInMemoryRunner(AfInMemoryRunner):
"""InMemoryRunner that is tailored for tests, features async run method.
app_name is hardcoded as InMemoryRunner in the parent class.
"""
async def run_async_with_new_session(
self, new_message: types.ContentUnion
) -> list[Event]:
session = await self.session_service.create_session(
app_name='InMemoryRunner', user_id='test_user'
)
collected_events = []
async for event in self.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=get_user_content(new_message),
):
collected_events.append(event)
return collected_events
class InMemoryRunner:
"""InMemoryRunner that is tailored for tests."""
def __init__(
self,
root_agent: Union[Agent, LlmAgent],
response_modalities: list[str] = None,
):
self.root_agent = root_agent
self.runner = Runner(
app_name='test_app',
agent=root_agent,
artifact_service=InMemoryArtifactService(),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
)
self.session_id = None
@property
def session(self) -> Session:
if not self.session_id:
session = asyncio.run(
self.runner.session_service.create_session(
app_name='test_app', user_id='test_user'
)
)
self.session_id = session.id
return session
return asyncio.run(
self.runner.session_service.get_session(
app_name='test_app', user_id='test_user', session_id=self.session_id
)
)
def run(self, new_message: types.ContentUnion) -> list[Event]:
return list(
self.runner.run(
user_id=self.session.user_id,
session_id=self.session.id,
new_message=get_user_content(new_message),
)
)
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
collected_responses = []
async def consume_responses(session: Session):
run_res = self.runner.run_live(
session=session,
live_request_queue=live_request_queue,
)
async for response in run_res:
collected_responses.append(response)
# When we have enough response, we should return
if len(collected_responses) >= 1:
return
try:
session = self.session
asyncio.run(consume_responses(session))
except asyncio.TimeoutError:
print('Returning any partial results collected so far.')
return collected_responses
class MockModel(BaseLlm):
model: str = 'mock'
requests: list[LlmRequest] = []
responses: list[LlmResponse]
response_index: int = -1
@classmethod
def create(
cls,
responses: Union[
list[types.Part], list[LlmResponse], list[str], list[list[types.Part]]
],
):
if not responses:
return cls(responses=[])
elif isinstance(responses[0], LlmResponse):
# responses is list[LlmResponse]
return cls(responses=responses)
else:
responses = [
LlmResponse(content=ModelContent(item))
if isinstance(item, list) and isinstance(item[0], types.Part)
# responses is list[list[Part]]
else LlmResponse(
content=ModelContent(
# responses is list[str] or list[Part]
[Part(text=item) if isinstance(item, str) else item]
)
)
for item in responses
if item
]
return cls(responses=responses)
@staticmethod
def supported_models() -> list[str]:
return ['mock']
def generate_content(
self, llm_request: LlmRequest, stream: bool = False
) -> Generator[LlmResponse, None, None]:
# Increasement of the index has to happen before the yield.
self.response_index += 1
self.requests.append(llm_request)
# yield LlmResponse(content=self.responses[self.response_index])
yield self.responses[self.response_index]
@override
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
# Increasement of the index has to happen before the yield.
self.response_index += 1
self.requests.append(llm_request)
yield self.responses[self.response_index]
@contextlib.asynccontextmanager
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
"""Creates a live connection to the LLM."""
yield MockLlmConnection(self.responses)
class MockLlmConnection(BaseLlmConnection):
def __init__(self, llm_responses: list[LlmResponse]):
self.llm_responses = llm_responses
async def send_history(self, history: list[types.Content]):
pass
async def send_content(self, content: types.Content):
pass
async def send(self, data):
pass
async def send_realtime(self, blob: types.Blob):
pass
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""Yield each of the pre-defined LlmResponses."""
for response in self.llm_responses:
yield response
async def close(self):
pass