Files
evo-ai/.venv/lib/python3.10/site-packages/google/adk/agents/base_agent.py
2025-04-25 15:30:54 -03:00

346 lines
10 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import AsyncGenerator
from typing import Callable
from typing import final
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from opentelemetry import trace
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from typing_extensions import override
from ..events.event import Event
from .callback_context import CallbackContext
if TYPE_CHECKING:
from .invocation_context import InvocationContext
# Tracer used to open an OpenTelemetry span around each agent run.
tracer = trace.get_tracer('gcp.vertex.agent')

BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
"""Callback signature that is invoked before the agent run.

Args:
  callback_context: MUST be named 'callback_context' (enforced: the callback is
    invoked with it as a keyword argument).

Returns:
  The content to return to the user. When set, the agent run will be skipped
  and the provided content will be returned to the user.
"""

AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
"""Callback signature that is invoked after the agent run.

Args:
  callback_context: MUST be named 'callback_context' (enforced: the callback is
    invoked with it as a keyword argument).

Returns:
  The content to return to the user. When set, the provided content will be
  appended to the event history as an additional agent response.
"""
class BaseAgent(BaseModel):
    """Base class for all agents in Agent Development Kit."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra='forbid',
    )

    name: str
    """The agent's name.

    Agent name must be a Python identifier and unique within the agent tree.
    Agent name cannot be "user", since it's reserved for end-user's input.
    """

    description: str = ''
    """Description about the agent's capability.

    The model uses this to determine whether to delegate control to the agent.
    One-line description is enough and preferred.
    """

    parent_agent: Optional[BaseAgent] = Field(default=None, init=False)
    """The parent agent of this agent.

    Note that an agent can ONLY be added as sub-agent once.

    If you want to add one agent twice as sub-agent, consider creating two agent
    instances with identical config but different names, and add them both to
    the agent tree.
    """

    sub_agents: list[BaseAgent] = Field(default_factory=list)
    """The sub-agents of this agent."""

    before_agent_callback: Optional[BeforeAgentCallback] = None
    """Callback invoked before the agent run.

    Args:
      callback_context: MUST be named 'callback_context' (enforced: the callback
        is invoked with it as a keyword argument).

    Returns:
      The content to return to the user. When set, the agent run will be
      skipped and the provided content will be returned to the user.
    """

    after_agent_callback: Optional[AfterAgentCallback] = None
    """Callback invoked after the agent run.

    Args:
      callback_context: MUST be named 'callback_context' (enforced: the callback
        is invoked with it as a keyword argument).

    Returns:
      The content to return to the user. When set, the provided content will be
      appended to the event history as an additional agent response.
    """

    @final
    async def run_async(
        self,
        parent_context: InvocationContext,
    ) -> AsyncGenerator[Event, None]:
        """Entry method to run an agent via text-based conversation.

        Order of execution: before_agent_callback, then `_run_async_impl`, then
        after_agent_callback. Each later stage is skipped once
        `ctx.end_invocation` is set.

        Args:
          parent_context: InvocationContext, the invocation context of the parent
            agent.

        Yields:
          Event: the events generated by the agent.
        """
        with tracer.start_as_current_span(f'agent_run [{self.name}]'):
            ctx = self._create_invocation_context(parent_context)

            if event := self.__handle_before_agent_callback(ctx):
                yield event
            # Stop if the invocation has ended, e.g. because the before-agent
            # callback returned content that replaces the agent's own run.
            if ctx.end_invocation:
                return

            async for event in self._run_async_impl(ctx):
                yield event

            # The agent implementation may itself end the invocation; in that
            # case the after-agent callback is not run.
            if ctx.end_invocation:
                return

            if event := self.__handle_after_agent_callback(ctx):
                yield event

    @final
    async def run_live(
        self,
        parent_context: InvocationContext,
    ) -> AsyncGenerator[Event, None]:
        """Entry method to run an agent via video/audio-based conversation.

        Unlike `run_async`, before/after agent callbacks are not invoked here.

        Args:
          parent_context: InvocationContext, the invocation context of the parent
            agent.

        Yields:
          Event: the events generated by the agent.
        """
        with tracer.start_as_current_span(f'agent_run [{self.name}]'):
            ctx = self._create_invocation_context(parent_context)
            # TODO(hangfei): support before/after_agent_callback
            async for event in self._run_live_impl(ctx):
                yield event

    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        """Core logic to run this agent via text-based conversation.

        Subclasses must override this; the base implementation raises.

        Args:
          ctx: InvocationContext, the invocation context for this agent.

        Yields:
          Event: the events generated by the agent.

        Raises:
          NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f'_run_async_impl for {type(self)} is not implemented.'
        )
        yield  # AsyncGenerator requires having at least one yield statement

    async def _run_live_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        """Core logic to run this agent via video/audio-based conversation.

        Subclasses must override this; the base implementation raises.

        Args:
          ctx: InvocationContext, the invocation context for this agent.

        Yields:
          Event: the events generated by the agent.

        Raises:
          NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f'_run_live_impl for {type(self)} is not implemented.'
        )
        yield  # AsyncGenerator requires having at least one yield statement

    @property
    def root_agent(self) -> BaseAgent:
        """Gets the root agent of this agent by following `parent_agent` links."""
        root_agent = self
        while root_agent.parent_agent is not None:
            root_agent = root_agent.parent_agent
        return root_agent

    def find_agent(self, name: str) -> Optional[BaseAgent]:
        """Finds the agent with the given name in this agent and its descendants.

        Args:
          name: The name of the agent to find.

        Returns:
          The agent with the matching name, or None if no such agent is found.
        """
        if self.name == name:
            return self
        return self.find_sub_agent(name)

    def find_sub_agent(self, name: str) -> Optional[BaseAgent]:
        """Finds the agent with the given name in this agent's descendants.

        The search is depth-first over `sub_agents`, in list order.

        Args:
          name: The name of the agent to find.

        Returns:
          The agent with the matching name, or None if no such agent is found.
        """
        for sub_agent in self.sub_agents:
            if (result := sub_agent.find_agent(name)) is not None:
                return result
        return None

    def _create_invocation_context(
        self, parent_context: InvocationContext
    ) -> InvocationContext:
        """Creates a new invocation context for this agent.

        The parent context is copied with `agent` set to this agent. If the
        parent has a non-empty branch, this agent's name is appended to it as
        `<parent_branch>.<name>`; otherwise the copied branch is kept as-is.
        """
        invocation_context = parent_context.model_copy(update={'agent': self})
        if parent_context.branch:
            invocation_context.branch = f'{parent_context.branch}.{self.name}'
        return invocation_context

    def __handle_before_agent_callback(
        self, ctx: InvocationContext
    ) -> Optional[Event]:
        """Runs the before_agent_callback if it exists.

        If the callback returns content, `ctx.end_invocation` is set to True so
        that `run_async` skips the agent's own run.

        Returns:
          Optional[Event]: an event if callback provides content or changed state.
        """
        ret_event = None

        if not callable(self.before_agent_callback):
            return ret_event

        callback_context = CallbackContext(ctx)
        before_agent_callback_content = self.before_agent_callback(
            callback_context=callback_context
        )

        if before_agent_callback_content:
            ret_event = Event(
                invocation_id=ctx.invocation_id,
                author=self.name,
                branch=ctx.branch,
                content=before_agent_callback_content,
                actions=callback_context._event_actions,
            )
            ctx.end_invocation = True
            return ret_event

        # No content, but the callback may still have modified state; surface
        # those changes through an event carrying the recorded actions.
        if callback_context.state.has_delta():
            ret_event = Event(
                invocation_id=ctx.invocation_id,
                author=self.name,
                branch=ctx.branch,
                actions=callback_context._event_actions,
            )

        return ret_event

    def __handle_after_agent_callback(
        self, invocation_context: InvocationContext
    ) -> Optional[Event]:
        """Runs the after_agent_callback if it exists.

        Returns:
          Optional[Event]: an event if callback provides content or changed state.
        """
        ret_event = None

        if not callable(self.after_agent_callback):
            return ret_event

        callback_context = CallbackContext(invocation_context)
        after_agent_callback_content = self.after_agent_callback(
            callback_context=callback_context
        )

        if after_agent_callback_content or callback_context.state.has_delta():
            ret_event = Event(
                invocation_id=invocation_context.invocation_id,
                author=self.name,
                branch=invocation_context.branch,
                content=after_agent_callback_content,
                actions=callback_context._event_actions,
            )

        return ret_event

    @override
    def model_post_init(self, __context: Any) -> None:
        """Pydantic post-init hook: links each sub-agent back to this agent."""
        self.__set_parent_agent_for_sub_agents()

    @field_validator('name', mode='after')
    @classmethod
    def __validate_name(cls, value: str):
        """Validates the agent name.

        Raises:
          ValueError: if `value` is not a Python identifier, or is the reserved
            name `user`.
        """
        if not value.isidentifier():
            raise ValueError(
                f'Found invalid agent name: `{value}`.'
                ' Agent name must be a valid identifier. It should start with a'
                ' letter (a-z, A-Z) or an underscore (_), and can only contain'
                ' letters, digits (0-9), and underscores.'
            )
        if value == 'user':
            raise ValueError(
                "Agent name cannot be `user`. `user` is reserved for end-user's"
                ' input.'
            )
        return value

    def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
        """Sets `parent_agent` of every sub-agent to this agent.

        Raises:
          ValueError: if a sub-agent already has a parent agent (an agent may
            only be attached to one parent).
        """
        for sub_agent in self.sub_agents:
            if sub_agent.parent_agent is not None:
                raise ValueError(
                    f'Agent `{sub_agent.name}` already has a parent agent, current'
                    f' parent: `{sub_agent.parent_agent.name}`, trying to add:'
                    f' `{self.name}`'
                )
            sub_agent.parent_agent = self
        return self