add toolset base class and allow llm agent to accept toolset as tools

PiperOrigin-RevId: 756605470
This commit is contained in:
Xiang (Sean) Zhou
2025-05-08 22:27:13 -07:00
committed by Copybara-Service
parent 8963300518
commit 4d7298e4f2
10 changed files with 129 additions and 39 deletions
@@ -60,11 +60,6 @@ class CallbackContext(ReadonlyContext):
"""
return self._state
@property
def user_content(self) -> Optional[types.Content]:
"""The user content that started this invocation. READONLY field."""
return self._invocation_context.user_content
async def load_artifact(
self, filename: str, version: Optional[int] = None
) -> Optional[types.Part]:
+28 -14
View File
@@ -15,7 +15,15 @@
from __future__ import annotations
import logging
from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Literal,
Optional,
Union,
)
from google.genai import types
from pydantic import BaseModel
@@ -38,6 +46,7 @@ from ..models.llm_response import LlmResponse
from ..models.registry import LLMRegistry
from ..planners.base_planner import BasePlanner
from ..tools.base_tool import BaseTool
from ..tools.base_toolset import BaseToolset
from ..tools.function_tool import FunctionTool
from ..tools.tool_context import ToolContext
from .base_agent import BaseAgent
@@ -89,18 +98,19 @@ AfterToolCallback: TypeAlias = Union[
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
ToolUnion: TypeAlias = Union[Callable, BaseTool]
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
ExamplesUnion = Union[list[Example], BaseExampleProvider]
def _convert_tool_union_to_tool(
tool_union: ToolUnion,
) -> BaseTool:
return (
tool_union
if isinstance(tool_union, BaseTool)
else FunctionTool(tool_union)
)
async def _convert_tool_union_to_tools(
tool_union: ToolUnion, ctx: ReadonlyContext
) -> list[BaseTool]:
if isinstance(tool_union, BaseTool):
return [tool_union]
if isinstance(tool_union, Callable):
return [FunctionTool(func=tool_union)]
return await tool_union.get_tools(ctx)
class LlmAgent(BaseAgent):
@@ -312,13 +322,17 @@ class LlmAgent(BaseAgent):
else:
return self.global_instruction(ctx)
@property
def canonical_tools(self) -> list[BaseTool]:
"""The resolved self.tools field as a list of BaseTool.
async def canonical_tools(
self, ctx: ReadonlyContext = None
) -> list[BaseTool]:
"""The resolved self.tools field as a list of BaseTool based on the context.
This method is only for use by Agent Development Kit.
"""
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
resolved_tools = []
for tool_union in self.tools:
resolved_tools.extend(await _convert_tool_union_to_tools(tool_union, ctx))
return resolved_tools
@property
def canonical_before_model_callbacks(
+7 -1
View File
@@ -15,10 +15,11 @@
from __future__ import annotations
from types import MappingProxyType
from typing import Any
from typing import Any, Optional
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from google.genai import types
from .invocation_context import InvocationContext
@@ -30,6 +31,11 @@ class ReadonlyContext:
) -> None:
self._invocation_context = invocation_context
@property
def user_content(self) -> Optional[types.Content]:
"""The user content that started this invocation. READONLY field."""
return self._invocation_context.user_content
@property
def invocation_id(self) -> str:
"""The current invocation id."""