# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import contextlib
from typing import AsyncGenerator
from typing import Generator
from typing import Optional
from typing import Union

from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.llm_agent import Agent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.events.event import Event
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.models.base_llm import BaseLlm
from google.adk.models.base_llm_connection import BaseLlmConnection
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.runners import InMemoryRunner as AfInMemoryRunner
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.session import Session
from google.genai import types
from google.genai.types import Part
from typing_extensions import override


class UserContent(types.Content):

  def __init__(self, text_or_part: Union[str, types.Part]):
    parts = [
        types.Part.from_text(text=text_or_part)
        if isinstance(text_or_part, str)
        else text_or_part
    ]
    super().__init__(role='user', parts=parts)


class ModelContent(types.Content):

  def __init__(self, parts: list[types.Part]):
    super().__init__(role='model', parts=parts)


async def create_invocation_context(agent: Agent, user_content: str = ''):
  invocation_id = 'test_id'
  artifact_service = InMemoryArtifactService()
  session_service = InMemorySessionService()
  memory_service = InMemoryMemoryService()
  invocation_context = InvocationContext(
      artifact_service=artifact_service,
      session_service=session_service,
      memory_service=memory_service,
      invocation_id=invocation_id,
      agent=agent,
      session=await session_service.create_session(
          app_name='test_app', user_id='test_user'
      ),
      user_content=types.Content(
          role='user', parts=[types.Part.from_text(text=user_content)]
      ),
      run_config=RunConfig(),
  )
  if user_content:
    append_user_content(
        invocation_context, [types.Part.from_text(text=user_content)]
    )
  return invocation_context


def append_user_content(
    invocation_context: InvocationContext, parts: list[types.Part]
) -> Event:
  session = invocation_context.session
  event = Event(
      invocation_id=invocation_context.invocation_id,
      author='user',
      content=types.Content(role='user', parts=parts),
  )
  session.events.append(event)
  return event


# Extracts the contents from the events and transforms them into a list of
# (author, simplified_content) tuples.
def simplify_events(events: list[Event]) -> list[tuple[str, types.Part]]:
  return [(event.author, simplify_content(event.content)) for event in events]


# Simplifies the contents into a list of (author, simplified_content) tuples.
def simplify_contents(
    contents: list[types.Content],
) -> list[tuple[str, types.Part]]:
  return [(content.role, simplify_content(content)) for content in contents]


# Simplifies the content so it's easier to assert.
# - If there is only one part, return part
# - If the only part is pure text, return stripped_text
# - If there are multiple parts, return parts
# - remove function_call_id if it exists
def simplify_content(
    content: types.Content,
) -> Union[str, types.Part, list[types.Part]]:
  for part in content.parts:
    if part.function_call and part.function_call.id:
      part.function_call.id = None
    if part.function_response and part.function_response.id:
      part.function_response.id = None
  if len(content.parts) == 1:
    if content.parts[0].text:
      return content.parts[0].text.strip()
    else:
      return content.parts[0]
  return content.parts


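# A minimal sketch of how the simplifiers above are typically used in
# assertions (hypothetical events and values, not taken from this file): a
# single text part collapses to its stripped text, a single non-text part is
# returned as the Part itself, and multiple parts come back as a list of
# Parts, with any function_call/function_response ids cleared first.
#
#   events = runner.run('hello')
#   assert simplify_events(events) == [('root_agent', 'Hi there!')]

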
def get_user_content(message: types.ContentUnion) -> types.Content:
  return message if isinstance(message, types.Content) else UserContent(message)


class TestInMemoryRunner(AfInMemoryRunner):
  """InMemoryRunner tailored for tests; provides an async run method.

  app_name is hardcoded as InMemoryRunner in the parent class.
  """

  async def run_async_with_new_session(
      self, new_message: types.ContentUnion
  ) -> list[Event]:
    session = await self.session_service.create_session(
        app_name='InMemoryRunner', user_id='test_user'
    )
    collected_events = []

    async for event in self.run_async(
        user_id=session.user_id,
        session_id=session.id,
        new_message=get_user_content(new_message),
    ):
      collected_events.append(event)

    return collected_events


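# A usage sketch for TestInMemoryRunner above (hypothetical test code,
# assuming an agent backed by the MockModel defined later in this module):
# the helper creates a fresh session and drains run_async into a list, so a
# test only has to await a single call.
#
#   runner = TestInMemoryRunner(Agent(name='root_agent', model=mock_model))
#   events = await runner.run_async_with_new_session('hello')
#   assert simplify_events(events) == [('root_agent', 'Hi')]

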
class InMemoryRunner:
  """InMemoryRunner that is tailored for tests."""

  def __init__(
      self,
      root_agent: Union[Agent, LlmAgent],
      response_modalities: Optional[list[str]] = None,
  ):
    self.root_agent = root_agent
    self.runner = Runner(
        app_name='test_app',
        agent=root_agent,
        artifact_service=InMemoryArtifactService(),
        session_service=InMemorySessionService(),
        memory_service=InMemoryMemoryService(),
    )
    self.session_id = None

  @property
  def session(self) -> Session:
    if not self.session_id:
      session = asyncio.run(
          self.runner.session_service.create_session(
              app_name='test_app', user_id='test_user'
          )
      )
      self.session_id = session.id
      return session
    return asyncio.run(
        self.runner.session_service.get_session(
            app_name='test_app', user_id='test_user', session_id=self.session_id
        )
    )

  def run(self, new_message: types.ContentUnion) -> list[Event]:
    return list(
        self.runner.run(
            user_id=self.session.user_id,
            session_id=self.session.id,
            new_message=get_user_content(new_message),
        )
    )

  def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
    collected_responses = []

    async def consume_responses(session: Session):
      run_res = self.runner.run_live(
          session=session,
          live_request_queue=live_request_queue,
      )

      async for response in run_res:
        collected_responses.append(response)
        # Stop once enough responses have been collected.
        if len(collected_responses) >= 1:
          return

    try:
      session = self.session
      asyncio.run(consume_responses(session))
    except asyncio.TimeoutError:
      print('Returning any partial results collected so far.')

    return collected_responses


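# A usage sketch for the InMemoryRunner above (hypothetical test code; the
# agent name and canned reply are illustrative): run() drives Runner.run
# against an in-memory session, and run_live() drains the first event
# produced for a LiveRequestQueue.
#
#   mock_model = MockModel.create(responses=['Hello back'])
#   runner = InMemoryRunner(Agent(name='root_agent', model=mock_model))
#   events = runner.run('hello')
#   assert simplify_events(events) == [('root_agent', 'Hello back')]

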
class MockModel(BaseLlm):
  model: str = 'mock'

  requests: list[LlmRequest] = []
  responses: list[LlmResponse]
  response_index: int = -1

  @classmethod
  def create(
      cls,
      responses: Union[
          list[types.Part], list[LlmResponse], list[str], list[list[types.Part]]
      ],
  ):
    if not responses:
      return cls(responses=[])
    elif isinstance(responses[0], LlmResponse):
      # responses is list[LlmResponse]
      return cls(responses=responses)
    else:
      responses = [
          LlmResponse(content=ModelContent(item))
          if isinstance(item, list) and isinstance(item[0], types.Part)
          # responses is list[list[Part]]
          else LlmResponse(
              content=ModelContent(
                  # responses is list[str] or list[Part]
                  [Part(text=item) if isinstance(item, str) else item]
              )
          )
          for item in responses
          if item
      ]

      return cls(responses=responses)

  @staticmethod
  def supported_models() -> list[str]:
    return ['mock']

  def generate_content(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> Generator[LlmResponse, None, None]:
    # The index must be incremented before the yield.
    self.response_index += 1
    self.requests.append(llm_request)
    yield self.responses[self.response_index]

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    # The index must be incremented before the yield.
    self.response_index += 1
    self.requests.append(llm_request)
    yield self.responses[self.response_index]

  @contextlib.asynccontextmanager
  async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
    """Creates a live connection to the LLM."""
    yield MockLlmConnection(self.responses)


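# The MockModel above accepts several response shapes in create(); a sketch of
# equivalent calls (illustrative values, not taken from this file):
#
#   MockModel.create(responses=['Hi'])                     # list[str]
#   MockModel.create(responses=[Part(text='Hi')])          # list[Part]
#   MockModel.create(responses=[[Part(text='a'), Part(text='b')]])  # list[list[Part]]
#   MockModel.create(responses=[LlmResponse(content=ModelContent([Part(text='Hi')]))])
#
# Each form yields a MockModel whose generate_content(_async) returns the
# canned LlmResponses in order while recording every LlmRequest in .requests.

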
class MockLlmConnection(BaseLlmConnection):

  def __init__(self, llm_responses: list[LlmResponse]):
    self.llm_responses = llm_responses

  async def send_history(self, history: list[types.Content]):
    pass

  async def send_content(self, content: types.Content):
    pass

  async def send(self, data):
    pass

  async def send_realtime(self, blob: types.Blob):
    pass

  async def receive(self) -> AsyncGenerator[LlmResponse, None]:
    """Yield each of the pre-defined LlmResponses."""
    for response in self.llm_responses:
      yield response

  async def close(self):
    pass
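

# How the live pieces above fit together (a description of the code in this
# module; the queue usage is a hypothetical example): during
# InMemoryRunner.run_live, MockModel.connect yields a MockLlmConnection whose
# receive() replays the pre-defined LlmResponses while the send_* methods and
# close() are no-ops, so a test only needs to push requests onto a
# LiveRequestQueue and assert on the collected events.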