Files
adk-python/src/google/adk/tools/agent_tool.py
hangfei 9827820143 Agent Development Kit(ADK)
An easy-to-use and powerful framework to build AI agents.
2025-04-08 17:25:47 +00:00

177 lines
5.7 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING
from google.genai import types
from pydantic import model_validator
from typing_extensions import override
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService
from . import _automatic_function_calling_util
from .base_tool import BaseTool
from .tool_context import ToolContext
if TYPE_CHECKING:
from ..agents.base_agent import BaseAgent
from ..agents.llm_agent import LlmAgent
class AgentTool(BaseTool):
  """A tool that wraps an agent.

  This tool allows an agent to be called as a tool within a larger application.
  The agent's input schema is used to define the tool's input parameters, and
  the agent's output is returned as the tool's result.

  Attributes:
    agent: The agent to wrap.
    skip_summarization: Whether to skip summarization of the agent output.
  """

  def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
    """Creates a tool that delegates calls to ``agent``.

    Args:
      agent: The agent to wrap. Its ``name`` and ``description`` become the
        tool's name and description.
      skip_summarization: Whether to skip summarization of the agent output.
        Defaults to ``False``; the attribute may still be changed after
        construction.
    """
    self.agent = agent
    # Whether to skip summarization of the agent output; read by run_async.
    self.skip_summarization: bool = skip_summarization
    super().__init__(name=agent.name, description=agent.description)

  @model_validator(mode='before')
  @classmethod
  def populate_name(cls, data: Any) -> Any:
    """Copies the wrapped agent's name into ``data['name']``.

    NOTE(review): this only has an effect if pydantic validation actually runs
    for this class; BaseTool is defined elsewhere, so confirm whether it is a
    pydantic model.

    Args:
      data: The raw input passed to a ``mode='before'`` validator.

    Returns:
      ``data``, with ``'name'`` set when it is a dict containing ``'agent'``.
      Any other input is returned unchanged instead of raising ``KeyError`` or
      ``TypeError`` from inside the validator.
    """
    if isinstance(data, dict) and 'agent' in data:
      data['name'] = data['agent'].name
    return data

  @override
  def _get_declaration(self) -> types.FunctionDeclaration:
    """Builds the function declaration the model sees for this tool.

    If the wrapped agent is an ``LlmAgent`` with an ``input_schema``, the
    parameters are derived from that schema. Otherwise the tool takes a single
    required string argument named ``request``.
    """
    # Imported lazily; presumably this avoids an import cycle with the agents
    # package (LlmAgent is only imported under TYPE_CHECKING at module level).
    from ..agents.llm_agent import LlmAgent

    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      result = _automatic_function_calling_util.build_function_declaration(
          func=self.agent.input_schema, variant=self._api_variant
      )
    else:
      result = types.FunctionDeclaration(
          parameters=types.Schema(
              type=types.Type.OBJECT,
              properties={
                  'request': types.Schema(
                      type=types.Type.STRING,
                  ),
              },
              required=['request'],
          ),
          description=self.agent.description,
          name=self.name,
      )
    # Whichever way the declaration was built, it must carry the tool's name.
    result.name = self.name
    return result

  @override
  async def run_async(
      self,
      *,
      args: dict[str, Any],
      tool_context: ToolContext,
  ) -> Any:
    """Runs the wrapped agent in a fresh in-memory session.

    The parent's current state is copied into the child session. State deltas
    emitted by the child are forwarded to ``tool_context.state`` as they
    arrive. Artifacts saved by the child are copied to the parent once the
    run finishes.

    Args:
      args: The tool-call arguments. These must match the agent's
        ``input_schema`` when it has one; otherwise they must contain a
        ``'request'`` string.
      tool_context: The context of the calling (parent) invocation.

    Returns:
      ``''`` if the child produced no text. Otherwise the dict dump of the
      agent's ``output_schema`` parsed from that text, when the agent has an
      output schema, or the text itself.

    Raises:
      KeyError: If no ``input_schema`` is set and ``'request'`` is missing.
      pydantic.ValidationError: If ``args`` do not match ``input_schema``, or
        the agent's text does not match ``output_schema``.
    """
    if self.skip_summarization:
      tool_context.actions.skip_summarization = True

    content = self._build_user_content(args)

    runner = Runner(
        app_name=self.agent.name,
        agent=self.agent,
        # TODO(kech): Remove the access to the invocation context.
        # It seems we don't need re-use artifact_service if we forward below.
        artifact_service=tool_context._invocation_context.artifact_service,
        session_service=InMemorySessionService(),
        memory_service=InMemoryMemoryService(),
    )
    session = runner.session_service.create_session(
        app_name=self.agent.name,
        user_id='tmp_user',
        state=tool_context.state.to_dict(),
    )

    last_content = None
    async for event in runner.run_async(
        user_id=session.user_id, session_id=session.id, new_message=content
    ):
      # Forward state delta to parent session.
      if event.actions.state_delta:
        tool_context.state.update(event.actions.state_delta)
      # Track the latest event that actually carries parts. Trailing events
      # without content (e.g. state-only updates) must not erase the answer.
      if event.content and event.content.parts:
        last_content = event.content

    self._forward_artifacts(runner, session, tool_context)
    return self._extract_tool_result(last_content)

  def _build_user_content(self, args: dict[str, Any]) -> types.Content:
    """Turns tool-call ``args`` into the user message sent to the agent."""
    from ..agents.llm_agent import LlmAgent

    if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
      # model_validate either raises or returns an instance of input_schema,
      # so no further type check is needed before serializing it back.
      input_value = self.agent.input_schema.model_validate(args)
      text = input_value.model_dump_json(exclude_none=True)
    else:
      text = args['request']
    return types.Content(role='user', parts=[types.Part.from_text(text=text)])

  @staticmethod
  def _forward_artifacts(
      runner: Runner, session: Any, tool_context: ToolContext
  ) -> None:
    """Copies every artifact stored for the child session into the parent."""
    artifact_service = runner.artifact_service
    if not artifact_service:
      return
    for artifact_name in artifact_service.list_artifact_keys(
        app_name=session.app_name,
        user_id=session.user_id,
        session_id=session.id,
    ):
      if artifact := artifact_service.load_artifact(
          app_name=session.app_name,
          user_id=session.user_id,
          session_id=session.id,
          filename=artifact_name,
      ):
        tool_context.save_artifact(filename=artifact_name, artifact=artifact)

  def _extract_tool_result(self, last_content: types.Content | None) -> Any:
    """Converts the agent's final content into the tool's return value."""
    from ..agents.llm_agent import LlmAgent

    if not last_content or not last_content.parts:
      return ''
    # Merge every non-thought text part. Reading only parts[0] would truncate
    # multi-part answers and return '' when the first part is not text.
    merged_text = '\n'.join(
        part.text
        for part in last_content.parts
        if part.text and not getattr(part, 'thought', False)
    )
    if not merged_text:
      return ''
    if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
      return self.agent.output_schema.model_validate_json(
          merged_text
      ).model_dump(exclude_none=True)
    return merged_text