# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import asyncio import contextlib from typing import AsyncGenerator from typing import Generator from typing import Union from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import Agent from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig from google.adk.artifacts import InMemoryArtifactService from google.adk.events.event import Event from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.models.base_llm import BaseLlm from google.adk.models.base_llm_connection import BaseLlmConnection from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.runners import InMemoryRunner as AfInMemoryRunner from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session from google.genai import types from google.genai.types import Part from typing_extensions import override class UserContent(types.Content): def __init__(self, text_or_part: str): parts = [ types.Part.from_text(text=text_or_part) if isinstance(text_or_part, str) else text_or_part ] super().__init__(role='user', parts=parts) class ModelContent(types.Content): def __init__(self, parts: list[types.Part]): super().__init__(role='model', parts=parts) async def create_invocation_context(agent: Agent, user_content: str = ''): invocation_id = 'test_id' artifact_service = InMemoryArtifactService() session_service = InMemorySessionService() memory_service = InMemoryMemoryService() invocation_context = InvocationContext( artifact_service=artifact_service, session_service=session_service, memory_service=memory_service, invocation_id=invocation_id, agent=agent, session=await session_service.create_session( app_name='test_app', user_id='test_user' ), user_content=types.Content( role='user', parts=[types.Part.from_text(text=user_content)] ), run_config=RunConfig(), ) if user_content: append_user_content( invocation_context, [types.Part.from_text(text=user_content)] ) return invocation_context def append_user_content( invocation_context: InvocationContext, parts: list[types.Part] ) -> Event: session = invocation_context.session event = Event( invocation_id=invocation_context.invocation_id, author='user', content=types.Content(role='user', parts=parts), ) session.events.append(event) return event # Extracts the contents from the events and transform them into a list of # (author, simplified_content) tuples. def simplify_events(events: list[Event]) -> list[(str, types.Part)]: return [(event.author, simplify_content(event.content)) for event in events] # Simplifies the contents into a list of (author, simplified_content) tuples. def simplify_contents(contents: list[types.Content]) -> list[(str, types.Part)]: return [(content.role, simplify_content(content)) for content in contents] # Simplifies the content so it's easier to assert. # - If there is only one part, return part # - If the only part is pure text, return stripped_text # - If there are multiple parts, return parts # - remove function_call_id if it exists def simplify_content( content: types.Content, ) -> Union[str, types.Part, list[types.Part]]: for part in content.parts: if part.function_call and part.function_call.id: part.function_call.id = None if part.function_response and part.function_response.id: part.function_response.id = None if len(content.parts) == 1: if content.parts[0].text: return content.parts[0].text.strip() else: return content.parts[0] return content.parts def get_user_content(message: types.ContentUnion) -> types.Content: return message if isinstance(message, types.Content) else UserContent(message) class TestInMemoryRunner(AfInMemoryRunner): """InMemoryRunner that is tailored for tests, features async run method. app_name is hardcoded as InMemoryRunner in the parent class. """ async def run_async_with_new_session( self, new_message: types.ContentUnion ) -> list[Event]: session = await self.session_service.create_session( app_name='InMemoryRunner', user_id='test_user' ) collected_events = [] async for event in self.run_async( user_id=session.user_id, session_id=session.id, new_message=get_user_content(new_message), ): collected_events.append(event) return collected_events class InMemoryRunner: """InMemoryRunner that is tailored for tests.""" def __init__( self, root_agent: Union[Agent, LlmAgent], response_modalities: list[str] = None, ): self.root_agent = root_agent self.runner = Runner( app_name='test_app', agent=root_agent, artifact_service=InMemoryArtifactService(), session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), ) self.session_id = None @property def session(self) -> Session: if not self.session_id: session = asyncio.run( self.runner.session_service.create_session( app_name='test_app', user_id='test_user' ) ) self.session_id = session.id return session return asyncio.run( self.runner.session_service.get_session( app_name='test_app', user_id='test_user', session_id=self.session_id ) ) def run(self, new_message: types.ContentUnion) -> list[Event]: return list( self.runner.run( user_id=self.session.user_id, session_id=self.session.id, new_message=get_user_content(new_message), ) ) def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]: collected_responses = [] async def consume_responses(session: Session): run_res = self.runner.run_live( session=session, live_request_queue=live_request_queue, ) async for response in run_res: collected_responses.append(response) # When we have enough response, we should return if len(collected_responses) >= 1: return try: session = self.session asyncio.run(consume_responses(session)) except asyncio.TimeoutError: print('Returning any partial results collected so far.') return collected_responses class MockModel(BaseLlm): model: str = 'mock' requests: list[LlmRequest] = [] responses: list[LlmResponse] response_index: int = -1 @classmethod def create( cls, responses: Union[ list[types.Part], list[LlmResponse], list[str], list[list[types.Part]] ], ): if not responses: return cls(responses=[]) elif isinstance(responses[0], LlmResponse): # responses is list[LlmResponse] return cls(responses=responses) else: responses = [ LlmResponse(content=ModelContent(item)) if isinstance(item, list) and isinstance(item[0], types.Part) # responses is list[list[Part]] else LlmResponse( content=ModelContent( # responses is list[str] or list[Part] [Part(text=item) if isinstance(item, str) else item] ) ) for item in responses if item ] return cls(responses=responses) @staticmethod def supported_models() -> list[str]: return ['mock'] def generate_content( self, llm_request: LlmRequest, stream: bool = False ) -> Generator[LlmResponse, None, None]: # Increasement of the index has to happen before the yield. self.response_index += 1 self.requests.append(llm_request) # yield LlmResponse(content=self.responses[self.response_index]) yield self.responses[self.response_index] @override async def generate_content_async( self, llm_request: LlmRequest, stream: bool = False ) -> AsyncGenerator[LlmResponse, None]: # Increasement of the index has to happen before the yield. self.response_index += 1 self.requests.append(llm_request) yield self.responses[self.response_index] @contextlib.asynccontextmanager async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: """Creates a live connection to the LLM.""" yield MockLlmConnection(self.responses) class MockLlmConnection(BaseLlmConnection): def __init__(self, llm_responses: list[LlmResponse]): self.llm_responses = llm_responses async def send_history(self, history: list[types.Content]): pass async def send_content(self, content: types.Content): pass async def send(self, data): pass async def send_realtime(self, blob: types.Blob): pass async def receive(self) -> AsyncGenerator[LlmResponse, None]: """Yield each of the pre-defined LlmResponses.""" for response in self.llm_responses: yield response async def close(self): pass