structure saas with tools
This commit is contained in:
20
.venv/lib/python3.10/site-packages/google/adk/__init__.py
Normal file
20
.venv/lib/python3.10/site-packages/google/adk/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import version
|
||||
from .agents.llm_agent import Agent
|
||||
from .runners import Runner
|
||||
|
||||
__version__ = version.__version__
|
||||
__all__ = ["Agent", "Runner"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
from .live_request_queue import LiveRequest
|
||||
from .live_request_queue import LiveRequestQueue
|
||||
from .llm_agent import Agent
|
||||
from .llm_agent import LlmAgent
|
||||
from .loop_agent import LoopAgent
|
||||
from .parallel_agent import ParallelAgent
|
||||
from .run_config import RunConfig
|
||||
from .sequential_agent import SequentialAgent
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'BaseAgent',
|
||||
'LlmAgent',
|
||||
'LoopAgent',
|
||||
'ParallelAgent',
|
||||
'SequentialAgent',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,38 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from .live_request_queue import LiveRequestQueue
|
||||
|
||||
|
||||
class ActiveStreamingTool(BaseModel):
|
||||
"""Manages streaming tool related resources during invocation."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
task: Optional[asyncio.Task] = None
|
||||
"""The active task of this streaming tool."""
|
||||
|
||||
stream: Optional[LiveRequestQueue] = None
|
||||
"""The active (input) streams of this streaming tool."""
|
||||
@@ -0,0 +1,345 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Callable
|
||||
from typing import final
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from opentelemetry import trace
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from .callback_context import CallbackContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
tracer = trace.get_tracer('gcp.vertex.agent')
|
||||
|
||||
BeforeAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
|
||||
"""Callback signature that is invoked before the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be returned to user.
|
||||
"""
|
||||
|
||||
AfterAgentCallback = Callable[[CallbackContext], Optional[types.Content]]
|
||||
"""Callback signature that is invoked after the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be appended to event history as agent response.
|
||||
"""
|
||||
|
||||
|
||||
class BaseAgent(BaseModel):
|
||||
"""Base class for all agents in Agent Development Kit."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
name: str
|
||||
"""The agent's name.
|
||||
|
||||
Agent name must be a Python identifier and unique within the agent tree.
|
||||
Agent name cannot be "user", since it's reserved for end-user's input.
|
||||
"""
|
||||
|
||||
description: str = ''
|
||||
"""Description about the agent's capability.
|
||||
|
||||
The model uses this to determine whether to delegate control to the agent.
|
||||
One-line description is enough and preferred.
|
||||
"""
|
||||
|
||||
parent_agent: Optional[BaseAgent] = Field(default=None, init=False)
|
||||
"""The parent agent of this agent.
|
||||
|
||||
Note that an agent can ONLY be added as sub-agent once.
|
||||
|
||||
If you want to add one agent twice as sub-agent, consider to create two agent
|
||||
instances with identical config, but with different name and add them to the
|
||||
agent tree.
|
||||
"""
|
||||
sub_agents: list[BaseAgent] = Field(default_factory=list)
|
||||
"""The sub-agents of this agent."""
|
||||
|
||||
before_agent_callback: Optional[BeforeAgentCallback] = None
|
||||
"""Callback signature that is invoked before the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be returned to user.
|
||||
"""
|
||||
after_agent_callback: Optional[AfterAgentCallback] = None
|
||||
"""Callback signature that is invoked after the agent run.
|
||||
|
||||
Args:
|
||||
callback_context: MUST be named 'callback_context' (enforced).
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When set, the agent run will skipped and
|
||||
the provided content will be appended to event history as agent response.
|
||||
"""
|
||||
|
||||
@final
|
||||
async def run_async(
|
||||
self,
|
||||
parent_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Entry method to run an agent via text-based conversation.
|
||||
|
||||
Args:
|
||||
parent_context: InvocationContext, the invocation context of the parent
|
||||
agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
|
||||
if event := self.__handle_before_agent_callback(ctx):
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
async for event in self._run_async_impl(ctx):
|
||||
yield event
|
||||
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
if event := self.__handle_after_agent_callback(ctx):
|
||||
yield event
|
||||
|
||||
@final
|
||||
async def run_live(
|
||||
self,
|
||||
parent_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Entry method to run an agent via video/audio-based conversation.
|
||||
|
||||
Args:
|
||||
parent_context: InvocationContext, the invocation context of the parent
|
||||
agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
# TODO(hangfei): support before/after_agent_callback
|
||||
|
||||
async for event in self._run_live_impl(ctx):
|
||||
yield event
|
||||
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Core logic to run this agent via text-based conversation.
|
||||
|
||||
Args:
|
||||
ctx: InvocationContext, the invocation context for this agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'_run_async_impl for {type(self)} is not implemented.'
|
||||
)
|
||||
yield # AsyncGenerator requires having at least one yield statement
|
||||
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Core logic to run this agent via video/audio-based conversation.
|
||||
|
||||
Args:
|
||||
ctx: InvocationContext, the invocation context for this agent.
|
||||
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'_run_live_impl for {type(self)} is not implemented.'
|
||||
)
|
||||
yield # AsyncGenerator requires having at least one yield statement
|
||||
|
||||
@property
|
||||
def root_agent(self) -> BaseAgent:
|
||||
"""Gets the root agent of this agent."""
|
||||
root_agent = self
|
||||
while root_agent.parent_agent is not None:
|
||||
root_agent = root_agent.parent_agent
|
||||
return root_agent
|
||||
|
||||
def find_agent(self, name: str) -> Optional[BaseAgent]:
|
||||
"""Finds the agent with the given name in this agent and its descendants.
|
||||
|
||||
Args:
|
||||
name: The name of the agent to find.
|
||||
|
||||
Returns:
|
||||
The agent with the matching name, or None if no such agent is found.
|
||||
"""
|
||||
if self.name == name:
|
||||
return self
|
||||
return self.find_sub_agent(name)
|
||||
|
||||
def find_sub_agent(self, name: str) -> Optional[BaseAgent]:
|
||||
"""Finds the agent with the given name in this agent's descendants.
|
||||
|
||||
Args:
|
||||
name: The name of the agent to find.
|
||||
|
||||
Returns:
|
||||
The agent with the matching name, or None if no such agent is found.
|
||||
"""
|
||||
for sub_agent in self.sub_agents:
|
||||
if result := sub_agent.find_agent(name):
|
||||
return result
|
||||
return None
|
||||
|
||||
def _create_invocation_context(
|
||||
self, parent_context: InvocationContext
|
||||
) -> InvocationContext:
|
||||
"""Creates a new invocation context for this agent."""
|
||||
invocation_context = parent_context.model_copy(update={'agent': self})
|
||||
if parent_context.branch:
|
||||
invocation_context.branch = f'{parent_context.branch}.{self.name}'
|
||||
return invocation_context
|
||||
|
||||
def __handle_before_agent_callback(
|
||||
self, ctx: InvocationContext
|
||||
) -> Optional[Event]:
|
||||
"""Runs the before_agent_callback if it exists.
|
||||
|
||||
Returns:
|
||||
Optional[Event]: an event if callback provides content or changed state.
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.before_agent_callback, Callable):
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(ctx)
|
||||
before_agent_callback_content = self.before_agent_callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
if before_agent_callback_content:
|
||||
ret_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
content=before_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
ctx.end_invocation = True
|
||||
return ret_event
|
||||
|
||||
if callback_context.state.has_delta():
|
||||
ret_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
|
||||
return ret_event
|
||||
|
||||
def __handle_after_agent_callback(
|
||||
self, invocation_context: InvocationContext
|
||||
) -> Optional[Event]:
|
||||
"""Runs the after_agent_callback if it exists.
|
||||
|
||||
Returns:
|
||||
Optional[Event]: an event if callback provides content or changed state.
|
||||
"""
|
||||
ret_event = None
|
||||
|
||||
if not isinstance(self.after_agent_callback, Callable):
|
||||
return ret_event
|
||||
|
||||
callback_context = CallbackContext(invocation_context)
|
||||
after_agent_callback_content = self.after_agent_callback(
|
||||
callback_context=callback_context
|
||||
)
|
||||
|
||||
if after_agent_callback_content or callback_context.state.has_delta():
|
||||
ret_event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=self.name,
|
||||
branch=invocation_context.branch,
|
||||
content=after_agent_callback_content,
|
||||
actions=callback_context._event_actions,
|
||||
)
|
||||
|
||||
return ret_event
|
||||
|
||||
@override
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self.__set_parent_agent_for_sub_agents()
|
||||
|
||||
@field_validator('name', mode='after')
|
||||
@classmethod
|
||||
def __validate_name(cls, value: str):
|
||||
if not value.isidentifier():
|
||||
raise ValueError(
|
||||
f'Found invalid agent name: `{value}`.'
|
||||
' Agent name must be a valid identifier. It should start with a'
|
||||
' letter (a-z, A-Z) or an underscore (_), and can only contain'
|
||||
' letters, digits (0-9), and underscores.'
|
||||
)
|
||||
if value == 'user':
|
||||
raise ValueError(
|
||||
"Agent name cannot be `user`. `user` is reserved for end-user's"
|
||||
' input.'
|
||||
)
|
||||
return value
|
||||
|
||||
def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
|
||||
for sub_agent in self.sub_agents:
|
||||
if sub_agent.parent_agent is not None:
|
||||
raise ValueError(
|
||||
f'Agent `{sub_agent.name}` already has a parent agent, current'
|
||||
f' parent: `{sub_agent.parent_agent.name}`, trying to add:'
|
||||
f' `{self.name}`'
|
||||
)
|
||||
sub_agent.parent_agent = self
|
||||
return self
|
||||
@@ -0,0 +1,111 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .readonly_context import ReadonlyContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.genai import types
|
||||
|
||||
from ..events.event_actions import EventActions
|
||||
from ..sessions.state import State
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
class CallbackContext(ReadonlyContext):
|
||||
"""The context of various callbacks within an agent run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
*,
|
||||
event_actions: Optional[EventActions] = None,
|
||||
) -> None:
|
||||
super().__init__(invocation_context)
|
||||
|
||||
from ..events.event_actions import EventActions
|
||||
from ..sessions.state import State
|
||||
|
||||
# TODO(weisun): make this public for Agent Development Kit, but private for
|
||||
# users.
|
||||
self._event_actions = event_actions or EventActions()
|
||||
self._state = State(
|
||||
value=invocation_context.session.state,
|
||||
delta=self._event_actions.state_delta,
|
||||
)
|
||||
|
||||
@property
|
||||
@override
|
||||
def state(self) -> State:
|
||||
"""The delta-aware state of the current session.
|
||||
|
||||
For any state change, you can mutate this object directly,
|
||||
e.g. `ctx.state['foo'] = 'bar'`
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def user_content(self) -> Optional[types.Content]:
|
||||
"""The user content that started this invocation. READONLY field."""
|
||||
return self._invocation_context.user_content
|
||||
|
||||
def load_artifact(
|
||||
self, filename: str, version: Optional[int] = None
|
||||
) -> Optional[types.Part]:
|
||||
"""Loads an artifact attached to the current session.
|
||||
|
||||
Args:
|
||||
filename: The filename of the artifact.
|
||||
version: The version of the artifact. If None, the latest version will be
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
The artifact.
|
||||
"""
|
||||
if self._invocation_context.artifact_service is None:
|
||||
raise ValueError("Artifact service is not initialized.")
|
||||
return self._invocation_context.artifact_service.load_artifact(
|
||||
app_name=self._invocation_context.app_name,
|
||||
user_id=self._invocation_context.user_id,
|
||||
session_id=self._invocation_context.session.id,
|
||||
filename=filename,
|
||||
version=version,
|
||||
)
|
||||
|
||||
def save_artifact(self, filename: str, artifact: types.Part) -> int:
|
||||
"""Saves an artifact and records it as delta for the current session.
|
||||
|
||||
Args:
|
||||
filename: The filename of the artifact.
|
||||
artifact: The artifact to save.
|
||||
|
||||
Returns:
|
||||
The version of the artifact.
|
||||
"""
|
||||
if self._invocation_context.artifact_service is None:
|
||||
raise ValueError("Artifact service is not initialized.")
|
||||
version = self._invocation_context.artifact_service.save_artifact(
|
||||
app_name=self._invocation_context.app_name,
|
||||
user_id=self._invocation_context.user_id,
|
||||
session_id=self._invocation_context.session.id,
|
||||
filename=filename,
|
||||
artifact=artifact,
|
||||
)
|
||||
self._event_actions.artifact_delta[filename] = version
|
||||
return version
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
from ..memory.base_memory_service import BaseMemoryService
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.session import Session
|
||||
from .active_streaming_tool import ActiveStreamingTool
|
||||
from .base_agent import BaseAgent
|
||||
from .live_request_queue import LiveRequestQueue
|
||||
from .run_config import RunConfig
|
||||
from .transcription_entry import TranscriptionEntry
|
||||
|
||||
|
||||
class LlmCallsLimitExceededError(Exception):
|
||||
"""Error thrown when the number of LLM calls exceed the limit."""
|
||||
|
||||
|
||||
class _InvocationCostManager(BaseModel):
|
||||
"""A container to keep track of the cost of invocation.
|
||||
|
||||
While we don't expected the metrics captured here to be a direct
|
||||
representatative of monetary cost incurred in executing the current
|
||||
invocation, but they, in someways have an indirect affect.
|
||||
"""
|
||||
|
||||
_number_of_llm_calls: int = 0
|
||||
"""A counter that keeps track of number of llm calls made."""
|
||||
|
||||
def increment_and_enforce_llm_calls_limit(
|
||||
self, run_config: Optional[RunConfig]
|
||||
):
|
||||
"""Increments _number_of_llm_calls and enforces the limit."""
|
||||
# We first increment the counter and then check the conditions.
|
||||
self._number_of_llm_calls += 1
|
||||
|
||||
if (
|
||||
run_config
|
||||
and run_config.max_llm_calls > 0
|
||||
and self._number_of_llm_calls > run_config.max_llm_calls
|
||||
):
|
||||
# We only enforce the limit if the limit is a positive number.
|
||||
raise LlmCallsLimitExceededError(
|
||||
"Max number of llm calls limit of"
|
||||
f" `{run_config.max_llm_calls}` exceeded"
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext(BaseModel):
|
||||
"""An invocation context represents the data of a single invocation of an agent.
|
||||
|
||||
An invocation:
|
||||
1. Starts with a user message and ends with a final response.
|
||||
2. Can contain one or multiple agent calls.
|
||||
3. Is handled by runner.run_async().
|
||||
|
||||
An invocation runs an agent until it does not request to transfer to another
|
||||
agent.
|
||||
|
||||
An agent call:
|
||||
1. Is handled by agent.run().
|
||||
2. Ends when agent.run() ends.
|
||||
|
||||
An LLM agent call is an agent with a BaseLLMFlow.
|
||||
An LLM agent call can contain one or multiple steps.
|
||||
|
||||
An LLM agent runs steps in a loop until:
|
||||
1. A final response is generated.
|
||||
2. The agent transfers to another agent.
|
||||
3. The end_invocation is set to true by any callbacks or tools.
|
||||
|
||||
A step:
|
||||
1. Calls the LLM only once and yields its response.
|
||||
2. Calls the tools and yields their responses if requested.
|
||||
|
||||
The summarization of the function response is considered another step, since
|
||||
it is another llm call.
|
||||
A step ends when it's done calling llm and tools, or if the end_invocation
|
||||
is set to true at any time.
|
||||
|
||||
```
|
||||
┌─────────────────────── invocation ──────────────────────────┐
|
||||
┌──────────── llm_agent_call_1 ────────────┐ ┌─ agent_call_2 ─┐
|
||||
┌──── step_1 ────────┐ ┌───── step_2 ──────┐
|
||||
[call_llm] [call_tool] [call_llm] [transfer]
|
||||
```
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
artifact_service: Optional[BaseArtifactService] = None
|
||||
session_service: BaseSessionService
|
||||
memory_service: Optional[BaseMemoryService] = None
|
||||
|
||||
invocation_id: str
|
||||
"""The id of this invocation context. Readonly."""
|
||||
branch: Optional[str] = None
|
||||
"""The branch of the invocation context.
|
||||
|
||||
The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
|
||||
agent_2, and agent_2 is the parent of agent_3.
|
||||
|
||||
Branch is used when multiple sub-agents shouldn't see their peer agents'
|
||||
conversation history.
|
||||
"""
|
||||
agent: BaseAgent
|
||||
"""The current agent of this invocation context. Readonly."""
|
||||
user_content: Optional[types.Content] = None
|
||||
"""The user content that started this invocation. Readonly."""
|
||||
session: Session
|
||||
"""The current session of this invocation context. Readonly."""
|
||||
|
||||
end_invocation: bool = False
|
||||
"""Whether to end this invocation.
|
||||
|
||||
Set to True in callbacks or tools to terminate this invocation."""
|
||||
|
||||
live_request_queue: Optional[LiveRequestQueue] = None
|
||||
"""The queue to receive live requests."""
|
||||
|
||||
active_streaming_tools: Optional[dict[str, ActiveStreamingTool]] = None
|
||||
"""The running streaming tools of this invocation."""
|
||||
|
||||
transcription_cache: Optional[list[TranscriptionEntry]] = None
|
||||
"""Caches necessary, data audio or contents, that are needed by transcription."""
|
||||
|
||||
run_config: Optional[RunConfig] = None
|
||||
"""Configurations for live agents under this invocation."""
|
||||
|
||||
_invocation_cost_manager: _InvocationCostManager = _InvocationCostManager()
|
||||
"""A container to keep track of different kinds of costs incurred as a part
|
||||
of this invocation.
|
||||
"""
|
||||
|
||||
def increment_llm_call_count(
|
||||
self,
|
||||
):
|
||||
"""Tracks number of llm calls made.
|
||||
|
||||
Raises:
|
||||
LlmCallsLimitExceededError: If number of llm calls made exceed the set
|
||||
threshold.
|
||||
"""
|
||||
self._invocation_cost_manager.increment_and_enforce_llm_calls_limit(
|
||||
self.run_config
|
||||
)
|
||||
|
||||
@property
|
||||
def app_name(self) -> str:
|
||||
return self.session.app_name
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
return self.session.user_id
|
||||
|
||||
|
||||
def new_invocation_context_id() -> str:
|
||||
return "e-" + str(uuid.uuid4())
|
||||
@@ -0,0 +1,140 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
def _get_last_human_messages(events: list[Event]) -> list[HumanMessage]:
|
||||
"""Extracts last human messages from given list of events.
|
||||
|
||||
Args:
|
||||
events: the list of events
|
||||
|
||||
Returns:
|
||||
list of last human messages
|
||||
"""
|
||||
messages = []
|
||||
for event in reversed(events):
|
||||
if messages and event.author != 'user':
|
||||
break
|
||||
if event.author == 'user' and event.content and event.content.parts:
|
||||
messages.append(HumanMessage(content=event.content.parts[0].text))
|
||||
return list(reversed(messages))
|
||||
|
||||
|
||||
class LangGraphAgent(BaseAgent):
|
||||
"""Currently a concept implementation, supports single and multi-turn."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
graph: CompiledGraph
|
||||
|
||||
instruction: str = ''
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self,
|
||||
ctx: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
|
||||
# Needed for langgraph checkpointer (for subsequent invocations; multi-turn)
|
||||
config: RunnableConfig = {'configurable': {'thread_id': ctx.session.id}}
|
||||
|
||||
# Add instruction as SystemMessage if graph state is empty
|
||||
current_graph_state = self.graph.get_state(config)
|
||||
graph_messages = (
|
||||
current_graph_state.values.get('messages', [])
|
||||
if current_graph_state.values
|
||||
else []
|
||||
)
|
||||
messages = (
|
||||
[SystemMessage(content=self.instruction)]
|
||||
if self.instruction and not graph_messages
|
||||
else []
|
||||
)
|
||||
# Add events to messages (evaluating the memory used; parent agent vs checkpointer)
|
||||
messages += self._get_messages(ctx.session.events)
|
||||
|
||||
# Use the Runnable
|
||||
final_state = self.graph.invoke({'messages': messages}, config)
|
||||
result = final_state['messages'][-1].content
|
||||
|
||||
result_event = Event(
|
||||
invocation_id=ctx.invocation_id,
|
||||
author=self.name,
|
||||
branch=ctx.branch,
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[types.Part.from_text(text=result)],
|
||||
),
|
||||
)
|
||||
yield result_event
|
||||
|
||||
def _get_messages(
|
||||
self, events: list[Event]
|
||||
) -> list[Union[HumanMessage, AIMessage]]:
|
||||
"""Extracts messages from given list of events.
|
||||
|
||||
If the developer provides their own memory within langgraph, we return the
|
||||
last user messages only. Otherwise, we return all messages between the user
|
||||
and the agent.
|
||||
|
||||
Args:
|
||||
events: the list of events
|
||||
|
||||
Returns:
|
||||
list of messages
|
||||
"""
|
||||
if self.graph.checkpointer:
|
||||
return _get_last_human_messages(events)
|
||||
else:
|
||||
return self._get_conversation_with_agent(events)
|
||||
|
||||
def _get_conversation_with_agent(
|
||||
self, events: list[Event]
|
||||
) -> list[Union[HumanMessage, AIMessage]]:
|
||||
"""Extracts messages from given list of events.
|
||||
|
||||
Args:
|
||||
events: the list of events
|
||||
|
||||
Returns:
|
||||
list of messages
|
||||
"""
|
||||
|
||||
messages = []
|
||||
for event in events:
|
||||
if not event.content or not event.content.parts:
|
||||
continue
|
||||
if event.author == 'user':
|
||||
messages.append(HumanMessage(content=event.content.parts[0].text))
|
||||
elif event.author == self.name:
|
||||
messages.append(AIMessage(content=event.content.parts[0].text))
|
||||
return messages
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class LiveRequest(BaseModel):
|
||||
"""Request send to live agents."""
|
||||
|
||||
model_config = ConfigDict(ser_json_bytes='base64', val_json_bytes='base64')
|
||||
|
||||
content: Optional[types.Content] = None
|
||||
"""If set, send the content to the model in turn-by-turn mode."""
|
||||
blob: Optional[types.Blob] = None
|
||||
"""If set, send the blob to the model in realtime mode."""
|
||||
close: bool = False
|
||||
"""If set, close the queue. queue.shutdown() is only supported in Python 3.13+."""
|
||||
|
||||
|
||||
class LiveRequestQueue:
|
||||
"""Queue used to send LiveRequest in a live(bidirectional streaming) way."""
|
||||
|
||||
def __init__(self):
|
||||
# Ensure there's an event loop available in this thread
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No running loop, create one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Now create the queue (it will use the event loop we just ensured exists)
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
def close(self):
|
||||
self._queue.put_nowait(LiveRequest(close=True))
|
||||
|
||||
def send_content(self, content: types.Content):
|
||||
self._queue.put_nowait(LiveRequest(content=content))
|
||||
|
||||
def send_realtime(self, blob: types.Blob):
|
||||
self._queue.put_nowait(LiveRequest(blob=blob))
|
||||
|
||||
def send(self, req: LiveRequest):
|
||||
self._queue.put_nowait(req)
|
||||
|
||||
async def get(self) -> LiveRequest:
|
||||
return await self._queue.get()
|
||||
@@ -0,0 +1,376 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Callable
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ..code_executors.base_code_executor import BaseCodeExecutor
|
||||
from ..events.event import Event
|
||||
from ..examples.base_example_provider import BaseExampleProvider
|
||||
from ..examples.example import Example
|
||||
from ..flows.llm_flows.auto_flow import AutoFlow
|
||||
from ..flows.llm_flows.base_llm_flow import BaseLlmFlow
|
||||
from ..flows.llm_flows.single_flow import SingleFlow
|
||||
from ..models.base_llm import BaseLlm
|
||||
from ..models.llm_request import LlmRequest
|
||||
from ..models.llm_response import LlmResponse
|
||||
from ..models.registry import LLMRegistry
|
||||
from ..planners.base_planner import BasePlanner
|
||||
from ..tools.base_tool import BaseTool
|
||||
from ..tools.function_tool import FunctionTool
|
||||
from ..tools.tool_context import ToolContext
|
||||
from .base_agent import BaseAgent
|
||||
from .callback_context import CallbackContext
|
||||
from .invocation_context import InvocationContext
|
||||
from .readonly_context import ReadonlyContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BeforeModelCallback: TypeAlias = Callable[
|
||||
[CallbackContext, LlmRequest], Optional[LlmResponse]
|
||||
]
|
||||
AfterModelCallback: TypeAlias = Callable[
|
||||
[CallbackContext, LlmResponse],
|
||||
Optional[LlmResponse],
|
||||
]
|
||||
BeforeToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext],
|
||||
Optional[dict],
|
||||
]
|
||||
AfterToolCallback: TypeAlias = Callable[
|
||||
[BaseTool, dict[str, Any], ToolContext, dict],
|
||||
Optional[dict],
|
||||
]
|
||||
|
||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
||||
|
||||
ToolUnion: TypeAlias = Union[Callable, BaseTool]
|
||||
ExamplesUnion = Union[list[Example], BaseExampleProvider]
|
||||
|
||||
|
||||
def _convert_tool_union_to_tool(
|
||||
tool_union: ToolUnion,
|
||||
) -> BaseTool:
|
||||
return (
|
||||
tool_union
|
||||
if isinstance(tool_union, BaseTool)
|
||||
else FunctionTool(tool_union)
|
||||
)
|
||||
|
||||
|
||||
class LlmAgent(BaseAgent):
|
||||
"""LLM-based Agent."""
|
||||
|
||||
model: Union[str, BaseLlm] = ''
|
||||
"""The model to use for the agent.
|
||||
|
||||
When not set, the agent will inherit the model from its ancestor.
|
||||
"""
|
||||
|
||||
instruction: Union[str, InstructionProvider] = ''
|
||||
"""Instructions for the LLM model, guiding the agent's behavior."""
|
||||
|
||||
global_instruction: Union[str, InstructionProvider] = ''
|
||||
"""Instructions for all the agents in the entire agent tree.
|
||||
|
||||
global_instruction ONLY takes effect in root agent.
|
||||
|
||||
For example: use global_instruction to make all agents have a stable identity
|
||||
or personality.
|
||||
"""
|
||||
|
||||
tools: list[ToolUnion] = Field(default_factory=list)
|
||||
"""Tools available to this agent."""
|
||||
|
||||
generate_content_config: Optional[types.GenerateContentConfig] = None
|
||||
"""The additional content generation configurations.
|
||||
|
||||
NOTE: not all fields are usable, e.g. tools must be configured via `tools`,
|
||||
thinking_config must be configured via `planner` in LlmAgent.
|
||||
|
||||
For example: use this config to adjust model temperature, configure safety
|
||||
settings, etc.
|
||||
"""
|
||||
|
||||
# LLM-based agent transfer configs - Start
|
||||
disallow_transfer_to_parent: bool = False
|
||||
"""Disallows LLM-controlled transferring to the parent agent."""
|
||||
disallow_transfer_to_peers: bool = False
|
||||
"""Disallows LLM-controlled transferring to the peer agents."""
|
||||
# LLM-based agent transfer configs - End
|
||||
|
||||
include_contents: Literal['default', 'none'] = 'default'
|
||||
"""Whether to include contents in the model request.
|
||||
|
||||
When set to 'none', the model request will not include any contents, such as
|
||||
user messages, tool results, etc.
|
||||
"""
|
||||
|
||||
# Controlled input/output configurations - Start
|
||||
input_schema: Optional[type[BaseModel]] = None
|
||||
"""The input schema when agent is used as a tool."""
|
||||
output_schema: Optional[type[BaseModel]] = None
|
||||
"""The output schema when agent replies.
|
||||
|
||||
NOTE: when this is set, agent can ONLY reply and CANNOT use any tools, such as
|
||||
function tools, RAGs, agent transfer, etc.
|
||||
"""
|
||||
output_key: Optional[str] = None
|
||||
"""The key in session state to store the output of the agent.
|
||||
|
||||
Typically use cases:
|
||||
- Extracts agent reply for later use, such as in tools, callbacks, etc.
|
||||
- Connects agents to coordinate with each other.
|
||||
"""
|
||||
# Controlled input/output configurations - End
|
||||
|
||||
# Advance features - Start
|
||||
planner: Optional[BasePlanner] = None
|
||||
"""Instructs the agent to make a plan and execute it step by step.
|
||||
|
||||
NOTE: to use model's built-in thinking features, set the `thinking_config`
|
||||
field in `google.adk.planners.built_in_planner`.
|
||||
|
||||
"""
|
||||
|
||||
code_executor: Optional[BaseCodeExecutor] = None
|
||||
"""Allow agent to execute code blocks from model responses using the provided
|
||||
CodeExecutor.
|
||||
|
||||
Check out available code executions in `google.adk.code_executor` package.
|
||||
|
||||
NOTE: to use model's built-in code executor, don't set this field, add
|
||||
`google.adk.tools.built_in_code_execution` to tools instead.
|
||||
"""
|
||||
# Advance features - End
|
||||
|
||||
# TODO: remove below fields after migration. - Start
|
||||
# These fields are added back for easier migration.
|
||||
examples: Optional[ExamplesUnion] = None
|
||||
# TODO: remove above fields after migration. - End
|
||||
|
||||
# Callbacks - Start
|
||||
before_model_callback: Optional[BeforeModelCallback] = None
|
||||
"""Called before calling the LLM.
|
||||
Args:
|
||||
callback_context: CallbackContext,
|
||||
llm_request: LlmRequest, The raw model request. Callback can mutate the
|
||||
request.
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When present, the model call will be
|
||||
skipped and the provided content will be returned to user.
|
||||
"""
|
||||
after_model_callback: Optional[AfterModelCallback] = None
|
||||
"""Called after calling LLM.
|
||||
|
||||
Args:
|
||||
callback_context: CallbackContext,
|
||||
llm_response: LlmResponse, the actual model response.
|
||||
|
||||
Returns:
|
||||
The content to return to the user. When present, the actual model response
|
||||
will be ignored and the provided content will be returned to user.
|
||||
"""
|
||||
before_tool_callback: Optional[BeforeToolCallback] = None
|
||||
"""Called before the tool is called.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
args: The arguments to the tool.
|
||||
tool_context: ToolContext,
|
||||
|
||||
Returns:
|
||||
The tool response. When present, the returned tool response will be used and
|
||||
the framework will skip calling the actual tool.
|
||||
"""
|
||||
after_tool_callback: Optional[AfterToolCallback] = None
|
||||
"""Called after the tool is called.
|
||||
|
||||
Args:
|
||||
tool: The tool to be called.
|
||||
args: The arguments to the tool.
|
||||
tool_context: ToolContext,
|
||||
tool_response: The response from the tool.
|
||||
|
||||
Returns:
|
||||
When present, the returned dict will be used as tool result.
|
||||
"""
|
||||
# Callbacks - End
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
async for event in self._llm_flow.run_async(ctx):
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
async for event in self._llm_flow.run_live(ctx):
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
@property
|
||||
def canonical_model(self) -> BaseLlm:
|
||||
"""The resolved self.model field as BaseLlm.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if isinstance(self.model, BaseLlm):
|
||||
return self.model
|
||||
elif self.model: # model is non-empty str
|
||||
return LLMRegistry.new_llm(self.model)
|
||||
else: # find model from ancestors.
|
||||
ancestor_agent = self.parent_agent
|
||||
while ancestor_agent is not None:
|
||||
if isinstance(ancestor_agent, LlmAgent):
|
||||
return ancestor_agent.canonical_model
|
||||
ancestor_agent = ancestor_agent.parent_agent
|
||||
raise ValueError(f'No model found for {self.name}.')
|
||||
|
||||
def canonical_instruction(self, ctx: ReadonlyContext) -> str:
|
||||
"""The resolved self.instruction field to construct instruction for this agent.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if isinstance(self.instruction, str):
|
||||
return self.instruction
|
||||
else:
|
||||
return self.instruction(ctx)
|
||||
|
||||
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
|
||||
"""The resolved self.instruction field to construct global instruction.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
if isinstance(self.global_instruction, str):
|
||||
return self.global_instruction
|
||||
else:
|
||||
return self.global_instruction(ctx)
|
||||
|
||||
@property
|
||||
def canonical_tools(self) -> list[BaseTool]:
|
||||
"""The resolved self.tools field as a list of BaseTool.
|
||||
|
||||
This method is only for use by Agent Development Kit.
|
||||
"""
|
||||
return [_convert_tool_union_to_tool(tool) for tool in self.tools]
|
||||
|
||||
@property
|
||||
def _llm_flow(self) -> BaseLlmFlow:
|
||||
if (
|
||||
self.disallow_transfer_to_parent
|
||||
and self.disallow_transfer_to_peers
|
||||
and not self.sub_agents
|
||||
):
|
||||
return SingleFlow()
|
||||
else:
|
||||
return AutoFlow()
|
||||
|
||||
def __maybe_save_output_to_state(self, event: Event):
|
||||
"""Saves the model output to state if needed."""
|
||||
if (
|
||||
self.output_key
|
||||
and event.is_final_response()
|
||||
and event.content
|
||||
and event.content.parts
|
||||
):
|
||||
result = ''.join(
|
||||
[part.text if part.text else '' for part in event.content.parts]
|
||||
)
|
||||
if self.output_schema:
|
||||
result = self.output_schema.model_validate_json(result).model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
event.actions.state_delta[self.output_key] = result
|
||||
|
||||
@model_validator(mode='after')
|
||||
def __model_validator_after(self) -> LlmAgent:
|
||||
self.__check_output_schema()
|
||||
return self
|
||||
|
||||
def __check_output_schema(self):
|
||||
if not self.output_schema:
|
||||
return
|
||||
|
||||
if (
|
||||
not self.disallow_transfer_to_parent
|
||||
or not self.disallow_transfer_to_peers
|
||||
):
|
||||
logger.warning(
|
||||
'Invalid config for agent %s: output_schema cannot co-exist with'
|
||||
' agent transfer configurations. Setting'
|
||||
' disallow_transfer_to_parent=True, disallow_transfer_to_peers=True',
|
||||
self.name,
|
||||
)
|
||||
self.disallow_transfer_to_parent = True
|
||||
self.disallow_transfer_to_peers = True
|
||||
|
||||
if self.sub_agents:
|
||||
raise ValueError(
|
||||
f'Invalid config for agent {self.name}: if output_schema is set,'
|
||||
' sub_agents must be empty to disable agent transfer.'
|
||||
)
|
||||
|
||||
if self.tools:
|
||||
raise ValueError(
|
||||
f'Invalid config for agent {self.name}: if output_schema is set,'
|
||||
' tools must be empty'
|
||||
)
|
||||
|
||||
@field_validator('generate_content_config', mode='after')
|
||||
@classmethod
|
||||
def __validate_generate_content_config(
|
||||
cls, generate_content_config: Optional[types.GenerateContentConfig]
|
||||
) -> types.GenerateContentConfig:
|
||||
if not generate_content_config:
|
||||
return types.GenerateContentConfig()
|
||||
if generate_content_config.thinking_config:
|
||||
raise ValueError('Thinking config should be set via LlmAgent.planner.')
|
||||
if generate_content_config.tools:
|
||||
raise ValueError('All tools must be set via LlmAgent.tools.')
|
||||
if generate_content_config.system_instruction:
|
||||
raise ValueError(
|
||||
'System instruction must be set via LlmAgent.instruction.'
|
||||
)
|
||||
if generate_content_config.response_schema:
|
||||
raise ValueError(
|
||||
'Response schema must be set via LlmAgent.output_schema.'
|
||||
)
|
||||
return generate_content_config
|
||||
|
||||
|
||||
Agent: TypeAlias = LlmAgent
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Loop agent implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class LoopAgent(BaseAgent):
|
||||
"""A shell agent that run its sub-agents in a loop.
|
||||
|
||||
When sub-agent generates an event with escalate or max_iterations are
|
||||
reached, the loop agent will stop.
|
||||
"""
|
||||
|
||||
max_iterations: Optional[int] = None
|
||||
"""The maximum number of iterations to run the loop agent.
|
||||
|
||||
If not set, the loop agent will run indefinitely until a sub-agent
|
||||
escalates.
|
||||
"""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
times_looped = 0
|
||||
while not self.max_iterations or times_looped < self.max_iterations:
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
if event.actions.escalate:
|
||||
return
|
||||
times_looped += 1
|
||||
return
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
raise NotImplementedError('The behavior for run_live is not defined yet.')
|
||||
yield # AsyncGenerator requires having at least one yield statement
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Parallel agent implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
def _set_branch_for_current_agent(
|
||||
current_agent: BaseAgent, invocation_context: InvocationContext
|
||||
):
|
||||
invocation_context.branch = (
|
||||
f"{invocation_context.branch}.{current_agent.name}"
|
||||
if invocation_context.branch
|
||||
else current_agent.name
|
||||
)
|
||||
|
||||
|
||||
async def _merge_agent_run(
|
||||
agent_runs: list[AsyncGenerator[Event, None]],
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Merges the agent run event generator.
|
||||
|
||||
This implementation guarantees for each agent, it won't move on until the
|
||||
generated event is processed by upstream runner.
|
||||
|
||||
Args:
|
||||
agent_runs: A list of async generators that yield events from each agent.
|
||||
|
||||
Yields:
|
||||
Event: The next event from the merged generator.
|
||||
"""
|
||||
tasks = [
|
||||
asyncio.create_task(events_for_one_agent.__anext__())
|
||||
for events_for_one_agent in agent_runs
|
||||
]
|
||||
pending_tasks = set(tasks)
|
||||
|
||||
while pending_tasks:
|
||||
done, pending_tasks = await asyncio.wait(
|
||||
pending_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in done:
|
||||
try:
|
||||
yield task.result()
|
||||
|
||||
# Find the generator that produced this event and move it on.
|
||||
for i, original_task in enumerate(tasks):
|
||||
if task == original_task:
|
||||
new_task = asyncio.create_task(agent_runs[i].__anext__())
|
||||
tasks[i] = new_task
|
||||
pending_tasks.add(new_task)
|
||||
break # stop iterating once found
|
||||
|
||||
except StopAsyncIteration:
|
||||
continue
|
||||
|
||||
|
||||
class ParallelAgent(BaseAgent):
|
||||
"""A shell agent that run its sub-agents in parallel in isolated manner.
|
||||
|
||||
This approach is beneficial for scenarios requiring multiple perspectives or
|
||||
attempts on a single task, such as:
|
||||
|
||||
- Running different algorithms simultaneously.
|
||||
- Generating multiple responses for review by a subsequent evaluation agent.
|
||||
"""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
_set_branch_for_current_agent(self, ctx)
|
||||
agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
|
||||
async for event in _merge_agent_run(agent_runs):
|
||||
yield event
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
class ReadonlyContext:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
) -> None:
|
||||
self._invocation_context = invocation_context
|
||||
|
||||
@property
|
||||
def invocation_id(self) -> str:
|
||||
"""The current invocation id."""
|
||||
return self._invocation_context.invocation_id
|
||||
|
||||
@property
|
||||
def agent_name(self) -> str:
|
||||
"""The name of the agent that is currently running."""
|
||||
return self._invocation_context.agent.name
|
||||
|
||||
@property
|
||||
def state(self) -> MappingProxyType[str, Any]:
|
||||
"""The state of the current session. READONLY field."""
|
||||
return MappingProxyType(self._invocation_context.session.state)
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import Field
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
from .invocation_context import InvocationContext
|
||||
|
||||
|
||||
class RemoteAgent(BaseAgent):
|
||||
"""Experimental, do not use."""
|
||||
|
||||
url: str
|
||||
|
||||
sub_agents: list[BaseAgent] = Field(
|
||||
default_factory=list, init=False, frozen=True
|
||||
)
|
||||
"""Sub-agent is disabled in RemoteAgent."""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
data = {
|
||||
'invocation_id': ctx.invocation_id,
|
||||
'session': ctx.session.model_dump(exclude_none=True),
|
||||
}
|
||||
events = requests.post(self.url, data=json.dumps(data), timeout=120)
|
||||
events.raise_for_status()
|
||||
for event in events.json():
|
||||
e = Event.model_validate(event)
|
||||
e.author = self.name
|
||||
yield e
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamingMode(Enum):
|
||||
NONE = None
|
||||
SSE = 'sse'
|
||||
BIDI = 'bidi'
|
||||
|
||||
|
||||
class RunConfig(BaseModel):
|
||||
"""Configs for runtime behavior of agents."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
speech_config: Optional[types.SpeechConfig] = None
|
||||
"""Speech configuration for the live agent."""
|
||||
|
||||
response_modalities: Optional[list[str]] = None
|
||||
"""The output modalities. If not set, it's default to AUDIO."""
|
||||
|
||||
save_input_blobs_as_artifacts: bool = False
|
||||
"""Whether or not to save the input blobs as artifacts."""
|
||||
|
||||
support_cfc: bool = False
|
||||
"""
|
||||
Whether to support CFC (Compositional Function Calling). Only applicable for
|
||||
StreamingMode.SSE. If it's true. the LIVE API will be invoked. Since only LIVE
|
||||
API supports CFC
|
||||
|
||||
.. warning::
|
||||
This feature is **experimental** and its API or behavior may change
|
||||
in future releases.
|
||||
"""
|
||||
|
||||
streaming_mode: StreamingMode = StreamingMode.NONE
|
||||
"""Streaming mode, None or StreamingMode.SSE or StreamingMode.BIDI."""
|
||||
|
||||
output_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
|
||||
"""Output transcription for live agents with audio response."""
|
||||
|
||||
max_llm_calls: int = 500
|
||||
"""
|
||||
A limit on the total number of llm calls for a given run.
|
||||
|
||||
Valid Values:
|
||||
- More than 0 and less than sys.maxsize: The bound on the number of llm
|
||||
calls is enforced, if the value is set in this range.
|
||||
- Less than or equal to 0: This allows for unbounded number of llm calls.
|
||||
"""
|
||||
|
||||
@field_validator('max_llm_calls', mode='after')
|
||||
@classmethod
|
||||
def validate_max_llm_calls(cls, value: int) -> int:
|
||||
if value == sys.maxsize:
|
||||
raise ValueError(f'max_llm_calls should be less than {sys.maxsize}.')
|
||||
elif value <= 0:
|
||||
logger.warning(
|
||||
'max_llm_calls is less than or equal to 0. This will result in'
|
||||
' no enforcement on total number of llm calls that will be made for a'
|
||||
' run. This may not be ideal, as this could result in a never'
|
||||
' ending communication between the model and the agent in certain'
|
||||
' cases.',
|
||||
)
|
||||
|
||||
return value
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Sequential agent implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from .base_agent import BaseAgent
|
||||
|
||||
|
||||
class SequentialAgent(BaseAgent):
|
||||
"""A shell agent that run its sub-agents in sequence."""
|
||||
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_live(ctx):
|
||||
yield event
|
||||
@@ -0,0 +1,34 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
|
||||
class TranscriptionEntry(BaseModel):
|
||||
"""Store the data that can be used for transcription."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra='forbid',
|
||||
)
|
||||
|
||||
role: str
|
||||
"""The role that created this data, typically "user" or "model"""
|
||||
|
||||
data: Union[types.Blob, types.Content]
|
||||
"""The data that can be used for transcription"""
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
from .gcs_artifact_service import GcsArtifactService
|
||||
from .in_memory_artifact_service import InMemoryArtifactService
|
||||
|
||||
__all__ = [
|
||||
'BaseArtifactService',
|
||||
'GcsArtifactService',
|
||||
'InMemoryArtifactService',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,128 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Abstract base class for artifact services."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
class BaseArtifactService(ABC):
|
||||
"""Abstract base class for artifact services."""
|
||||
|
||||
@abstractmethod
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
"""Saves an artifact to the artifact service storage.
|
||||
|
||||
The artifact is a file identified by the app name, user ID, session ID, and
|
||||
filename. After saving the artifact, a revision ID is returned to identify
|
||||
the artifact version.
|
||||
|
||||
Args:
|
||||
app_name: The app name.
|
||||
user_id: The user ID.
|
||||
session_id: The session ID.
|
||||
filename: The filename of the artifact.
|
||||
artifact: The artifact to save.
|
||||
|
||||
Returns:
|
||||
The revision ID. The first version of the artifact has a revision ID of 0.
|
||||
This is incremented by 1 after each successful save.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
"""Gets an artifact from the artifact service storage.
|
||||
|
||||
The artifact is a file identified by the app name, user ID, session ID, and
|
||||
filename.
|
||||
|
||||
Args:
|
||||
app_name: The app name.
|
||||
user_id: The user ID.
|
||||
session_id: The session ID.
|
||||
filename: The filename of the artifact.
|
||||
version: The version of the artifact. If None, the latest version will be
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
The artifact or None if not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
"""Lists all the artifact filenames within a session.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
|
||||
Returns:
|
||||
A list of all artifact filenames within a session.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
"""Deletes an artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
"""Lists all versions of an artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
|
||||
Returns:
|
||||
A list of all available versions of the artifact.
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,195 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An artifact service implementation using Google Cloud Storage (GCS)."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.cloud import storage
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GcsArtifactService(BaseArtifactService):
|
||||
"""An artifact service implementation using Google Cloud Storage (GCS)."""
|
||||
|
||||
def __init__(self, bucket_name: str, **kwargs):
|
||||
"""Initializes the GcsArtifactService.
|
||||
|
||||
Args:
|
||||
bucket_name: The name of the bucket to use.
|
||||
**kwargs: Keyword arguments to pass to the Google Cloud Storage client.
|
||||
"""
|
||||
self.bucket_name = bucket_name
|
||||
self.storage_client = storage.Client(**kwargs)
|
||||
self.bucket = self.storage_client.bucket(self.bucket_name)
|
||||
|
||||
def _file_has_user_namespace(self, filename: str) -> bool:
|
||||
"""Checks if the filename has a user namespace.
|
||||
|
||||
Args:
|
||||
filename: The filename to check.
|
||||
|
||||
Returns:
|
||||
True if the filename has a user namespace (starts with "user:"),
|
||||
False otherwise.
|
||||
"""
|
||||
return filename.startswith("user:")
|
||||
|
||||
def _get_blob_name(
|
||||
self,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: int,
|
||||
) -> str:
|
||||
"""Constructs the blob name in GCS.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
version: The version of the artifact.
|
||||
|
||||
Returns:
|
||||
The constructed blob name in GCS.
|
||||
"""
|
||||
if self._file_has_user_namespace(filename):
|
||||
return f"{app_name}/{user_id}/user/{filename}/{version}"
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
|
||||
|
||||
@override
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
version = 0 if not versions else max(versions) + 1
|
||||
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
|
||||
blob.upload_from_string(
|
||||
data=artifact.inline_data.data,
|
||||
content_type=artifact.inline_data.mime_type,
|
||||
)
|
||||
|
||||
return version
|
||||
|
||||
@override
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
if version is None:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
if not versions:
|
||||
return None
|
||||
version = max(versions)
|
||||
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
|
||||
artifact_bytes = blob.download_as_bytes()
|
||||
if not artifact_bytes:
|
||||
return None
|
||||
artifact = types.Part.from_bytes(
|
||||
data=artifact_bytes, mime_type=blob.content_type
|
||||
)
|
||||
return artifact
|
||||
|
||||
@override
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
filenames = set()
|
||||
|
||||
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
||||
session_blobs = self.storage_client.list_blobs(
|
||||
self.bucket, prefix=session_prefix
|
||||
)
|
||||
for blob in session_blobs:
|
||||
_, _, _, filename, _ = blob.name.split("/")
|
||||
filenames.add(filename)
|
||||
|
||||
user_namespace_prefix = f"{app_name}/{user_id}/user/"
|
||||
user_namespace_blobs = self.storage_client.list_blobs(
|
||||
self.bucket, prefix=user_namespace_prefix
|
||||
)
|
||||
for blob in user_namespace_blobs:
|
||||
_, _, _, filename, _ = blob.name.split("/")
|
||||
filenames.add(filename)
|
||||
|
||||
return sorted(list(filenames))
|
||||
|
||||
@override
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
for version in versions:
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
blob.delete()
|
||||
return
|
||||
|
||||
@override
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
|
||||
blobs = self.storage_client.list_blobs(self.bucket, prefix=prefix)
|
||||
versions = []
|
||||
for blob in blobs:
|
||||
_, _, _, _, version = blob.name.split("/")
|
||||
versions.append(int(version))
|
||||
return versions
|
||||
@@ -0,0 +1,133 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An in-memory implementation of the artifact service."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
"""An in-memory implementation of the artifact service."""
|
||||
|
||||
artifacts: dict[str, list[types.Part]] = Field(default_factory=dict)
|
||||
|
||||
def _file_has_user_namespace(self, filename: str) -> bool:
|
||||
"""Checks if the filename has a user namespace.
|
||||
|
||||
Args:
|
||||
filename: The filename to check.
|
||||
|
||||
Returns:
|
||||
True if the filename has a user namespace (starts with "user:"),
|
||||
False otherwise.
|
||||
"""
|
||||
return filename.startswith("user:")
|
||||
|
||||
def _artifact_path(
|
||||
self, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> str:
|
||||
"""Constructs the artifact path.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
|
||||
Returns:
|
||||
The constructed artifact path.
|
||||
"""
|
||||
if self._file_has_user_namespace(filename):
|
||||
return f"{app_name}/{user_id}/user/{filename}"
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}"
|
||||
|
||||
@override
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
if path not in self.artifacts:
|
||||
self.artifacts[path] = []
|
||||
version = len(self.artifacts[path])
|
||||
self.artifacts[path].append(artifact)
|
||||
return version
|
||||
|
||||
@override
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
versions = self.artifacts.get(path)
|
||||
if not versions:
|
||||
return None
|
||||
if version is None:
|
||||
version = -1
|
||||
return versions[version]
|
||||
|
||||
@override
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
||||
usernamespace_prefix = f"{app_name}/{user_id}/user/"
|
||||
filenames = []
|
||||
for path in self.artifacts:
|
||||
if path.startswith(session_prefix):
|
||||
filename = path.removeprefix(session_prefix)
|
||||
filenames.append(filename)
|
||||
elif path.startswith(usernamespace_prefix):
|
||||
filename = path.removeprefix(usernamespace_prefix)
|
||||
filenames.append(filename)
|
||||
return sorted(filenames)
|
||||
|
||||
@override
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
if not self.artifacts.get(path):
|
||||
return None
|
||||
self.artifacts.pop(path, None)
|
||||
|
||||
@override
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
versions = self.artifacts.get(path)
|
||||
if not versions:
|
||||
return []
|
||||
return list(range(len(versions)))
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_credential import AuthCredentialTypes
|
||||
from .auth_credential import OAuth2Auth
|
||||
from .auth_handler import AuthHandler
|
||||
from .auth_schemes import AuthScheme
|
||||
from .auth_schemes import AuthSchemeType
|
||||
from .auth_schemes import OpenIdConnectWithConfig
|
||||
from .auth_tool import AuthConfig
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,221 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class HttpCredentials(BaseModelWithConfig):
|
||||
"""Represents the secret token value for HTTP authentication, like user name, password, oauth token, etc."""
|
||||
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
token: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, data: Dict[str, Any]) -> "HttpCredentials":
|
||||
return cls(
|
||||
username=data.get("username"),
|
||||
password=data.get("password"),
|
||||
token=data.get("token"),
|
||||
)
|
||||
|
||||
|
||||
class HttpAuth(BaseModelWithConfig):
|
||||
"""The credentials and metadata for HTTP authentication."""
|
||||
|
||||
# The name of the HTTP Authorization scheme to be used in the Authorization
|
||||
# header as defined in RFC7235. The values used SHOULD be registered in the
|
||||
# IANA Authentication Scheme registry.
|
||||
# Examples: 'basic', 'bearer'
|
||||
scheme: str
|
||||
credentials: HttpCredentials
|
||||
|
||||
|
||||
class OAuth2Auth(BaseModelWithConfig):
|
||||
"""Represents credential value and its metadata for a OAuth2 credential."""
|
||||
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
# tool or adk can generate the auth_uri with the state info thus client
|
||||
# can verify the state
|
||||
auth_uri: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
# tool or adk can decide the redirect_uri if they don't want client to decide
|
||||
redirect_uri: Optional[str] = None
|
||||
auth_response_uri: Optional[str] = None
|
||||
auth_code: Optional[str] = None
|
||||
access_token: Optional[str] = None
|
||||
refresh_token: Optional[str] = None
|
||||
|
||||
|
||||
class ServiceAccountCredential(BaseModelWithConfig):
|
||||
"""Represents Google Service Account configuration.
|
||||
|
||||
Attributes:
|
||||
type: The type should be "service_account".
|
||||
project_id: The project ID.
|
||||
private_key_id: The ID of the private key.
|
||||
private_key: The private key.
|
||||
client_email: The client email.
|
||||
client_id: The client ID.
|
||||
auth_uri: The authorization URI.
|
||||
token_uri: The token URI.
|
||||
auth_provider_x509_cert_url: URL for auth provider's X.509 cert.
|
||||
client_x509_cert_url: URL for the client's X.509 cert.
|
||||
universe_domain: The universe domain.
|
||||
|
||||
Example:
|
||||
|
||||
config = ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/...",
|
||||
universe_domain="googleapis.com"
|
||||
)
|
||||
|
||||
|
||||
config = ServiceAccountConfig.model_construct(**{
|
||||
...service account config dict
|
||||
})
|
||||
"""
|
||||
|
||||
type_: str = Field("", alias="type")
|
||||
project_id: str
|
||||
private_key_id: str
|
||||
private_key: str
|
||||
client_email: str
|
||||
client_id: str
|
||||
auth_uri: str
|
||||
token_uri: str
|
||||
auth_provider_x509_cert_url: str
|
||||
client_x509_cert_url: str
|
||||
universe_domain: str
|
||||
|
||||
|
||||
class ServiceAccount(BaseModelWithConfig):
|
||||
"""Represents Google Service Account configuration."""
|
||||
|
||||
service_account_credential: Optional[ServiceAccountCredential] = None
|
||||
scopes: List[str]
|
||||
use_default_credential: Optional[bool] = False
|
||||
|
||||
|
||||
class AuthCredentialTypes(str, Enum):
|
||||
"""Represents the type of authentication credential."""
|
||||
|
||||
# API Key credential:
|
||||
# https://swagger.io/docs/specification/v3_0/authentication/api-keys/
|
||||
API_KEY = "apiKey"
|
||||
|
||||
# Credentials for HTTP Auth schemes:
|
||||
# https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml
|
||||
HTTP = "http"
|
||||
|
||||
# OAuth2 credentials:
|
||||
# https://swagger.io/docs/specification/v3_0/authentication/oauth2/
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
# OpenID Connect credentials:
|
||||
# https://swagger.io/docs/specification/v3_0/authentication/openid-connect-discovery/
|
||||
OPEN_ID_CONNECT = "openIdConnect"
|
||||
|
||||
# Service Account credentials:
|
||||
# https://cloud.google.com/iam/docs/service-account-creds
|
||||
SERVICE_ACCOUNT = "serviceAccount"
|
||||
|
||||
|
||||
class AuthCredential(BaseModelWithConfig):
|
||||
"""Data class representing an authentication credential.
|
||||
|
||||
To exchange for the actual credential, please use
|
||||
CredentialExchanger.exchange_credential().
|
||||
|
||||
Examples: API Key Auth
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY,
|
||||
api_key="1234",
|
||||
)
|
||||
|
||||
Example: HTTP Auth
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="password"),
|
||||
),
|
||||
)
|
||||
|
||||
Example: OAuth2 Bearer Token in HTTP Header
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token="eyAkaknabna...."),
|
||||
),
|
||||
)
|
||||
|
||||
Example: OAuth2 Auth with Authorization Code Flow
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="1234",
|
||||
client_secret="secret",
|
||||
),
|
||||
)
|
||||
|
||||
Example: OpenID Connect Auth
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OPEN_ID_CONNECT,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id="1234",
|
||||
client_secret="secret",
|
||||
redirect_uri="https://example.com",
|
||||
scopes=["scope1", "scope2"],
|
||||
),
|
||||
)
|
||||
|
||||
Example: Auth with resource reference
|
||||
AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY,
|
||||
resource_ref="projects/1234/locations/us-central1/resources/resource1",
|
||||
)
|
||||
"""
|
||||
|
||||
auth_type: AuthCredentialTypes
|
||||
# Resource reference for the credential.
|
||||
# This will be supported in the future.
|
||||
resource_ref: Optional[str] = None
|
||||
|
||||
api_key: Optional[str] = None
|
||||
http: Optional[HttpAuth] = None
|
||||
service_account: Optional[ServiceAccount] = None
|
||||
oauth2: Optional[OAuth2Auth] = None
|
||||
@@ -0,0 +1,272 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import SecurityBase
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_credential import AuthCredentialTypes
|
||||
from .auth_credential import OAuth2Auth
|
||||
from .auth_schemes import AuthSchemeType
|
||||
from .auth_schemes import OAuthGrantType
|
||||
from .auth_schemes import OpenIdConnectWithConfig
|
||||
from .auth_tool import AuthConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..sessions.state import State
|
||||
|
||||
try:
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
|
||||
SUPPORT_TOKEN_EXCHANGE = True
|
||||
except ImportError:
|
||||
SUPPORT_TOKEN_EXCHANGE = False
|
||||
|
||||
|
||||
class AuthHandler:
|
||||
|
||||
def __init__(self, auth_config: AuthConfig):
|
||||
self.auth_config = auth_config
|
||||
|
||||
def exchange_auth_token(
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
"""Generates an auth token from the authorization response.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the access token.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token endpoint is not configured in the auth
|
||||
scheme.
|
||||
AuthCredentialMissingError: If the access token cannot be retrieved
|
||||
from the token endpoint.
|
||||
"""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.exchanged_auth_credential
|
||||
if not SUPPORT_TOKEN_EXCHANGE:
|
||||
return auth_credential
|
||||
if isinstance(auth_scheme, OpenIdConnectWithConfig):
|
||||
if not hasattr(auth_scheme, "token_endpoint"):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
token_endpoint = auth_scheme.token_endpoint
|
||||
scopes = auth_scheme.scopes
|
||||
elif isinstance(auth_scheme, OAuth2):
|
||||
if (
|
||||
not auth_scheme.flows.authorizationCode
|
||||
or not auth_scheme.flows.authorizationCode.tokenUrl
|
||||
):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
|
||||
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
|
||||
else:
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
|
||||
if (
|
||||
not auth_credential
|
||||
or not auth_credential.oauth2
|
||||
or not auth_credential.oauth2.client_id
|
||||
or not auth_credential.oauth2.client_secret
|
||||
or auth_credential.oauth2.access_token
|
||||
or auth_credential.oauth2.refresh_token
|
||||
):
|
||||
return self.auth_config.exchanged_auth_credential
|
||||
|
||||
client = OAuth2Session(
|
||||
auth_credential.oauth2.client_id,
|
||||
auth_credential.oauth2.client_secret,
|
||||
scope=" ".join(scopes),
|
||||
redirect_uri=auth_credential.oauth2.redirect_uri,
|
||||
state=auth_credential.oauth2.state,
|
||||
)
|
||||
tokens = client.fetch_token(
|
||||
token_endpoint,
|
||||
authorization_response=auth_credential.oauth2.auth_response_uri,
|
||||
code=auth_credential.oauth2.auth_code,
|
||||
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
|
||||
)
|
||||
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
access_token=tokens.get("access_token"),
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
def parse_and_store_auth_response(self, state: State) -> None:
|
||||
|
||||
credential_key = self.get_credential_key()
|
||||
|
||||
state[credential_key] = self.auth_config.exchanged_auth_credential
|
||||
if not isinstance(
|
||||
self.auth_config.auth_scheme, SecurityBase
|
||||
) or self.auth_config.auth_scheme.type_ not in (
|
||||
AuthSchemeType.oauth2,
|
||||
AuthSchemeType.openIdConnect,
|
||||
):
|
||||
return
|
||||
|
||||
state[credential_key] = self.exchange_auth_token()
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not self.auth_scheme:
|
||||
raise ValueError("auth_scheme is empty.")
|
||||
|
||||
def get_auth_response(self, state: State) -> AuthCredential:
|
||||
credential_key = self.get_credential_key()
|
||||
return state.get(credential_key, None)
|
||||
|
||||
def generate_auth_request(self) -> AuthConfig:
|
||||
if not isinstance(
|
||||
self.auth_config.auth_scheme, SecurityBase
|
||||
) or self.auth_config.auth_scheme.type_ not in (
|
||||
AuthSchemeType.oauth2,
|
||||
AuthSchemeType.openIdConnect,
|
||||
):
|
||||
return self.auth_config.model_copy(deep=True)
|
||||
|
||||
# auth_uri already in exchanged credential
|
||||
if (
|
||||
self.auth_config.exchanged_auth_credential
|
||||
and self.auth_config.exchanged_auth_credential.oauth2
|
||||
and self.auth_config.exchanged_auth_credential.oauth2.auth_uri
|
||||
):
|
||||
return self.auth_config.model_copy(deep=True)
|
||||
|
||||
# Check if raw_auth_credential exists
|
||||
if not self.auth_config.raw_auth_credential:
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires"
|
||||
" auth_credential."
|
||||
)
|
||||
|
||||
# Check if oauth2 exists in raw_auth_credential
|
||||
if not self.auth_config.raw_auth_credential.oauth2:
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires oauth2 in"
|
||||
" auth_credential."
|
||||
)
|
||||
|
||||
# auth_uri in raw credential
|
||||
if self.auth_config.raw_auth_credential.oauth2.auth_uri:
|
||||
return AuthConfig(
|
||||
auth_scheme=self.auth_config.auth_scheme,
|
||||
raw_auth_credential=self.auth_config.raw_auth_credential,
|
||||
exchanged_auth_credential=self.auth_config.raw_auth_credential.model_copy(
|
||||
deep=True
|
||||
),
|
||||
)
|
||||
|
||||
# Check for client_id and client_secret
|
||||
if (
|
||||
not self.auth_config.raw_auth_credential.oauth2.client_id
|
||||
or not self.auth_config.raw_auth_credential.oauth2.client_secret
|
||||
):
|
||||
raise ValueError(
|
||||
f"Auth Scheme {self.auth_config.auth_scheme.type_} requires both"
|
||||
" client_id and client_secret in auth_credential.oauth2."
|
||||
)
|
||||
|
||||
# Generate new auth URI
|
||||
exchanged_credential = self.generate_auth_uri()
|
||||
return AuthConfig(
|
||||
auth_scheme=self.auth_config.auth_scheme,
|
||||
raw_auth_credential=self.auth_config.raw_auth_credential,
|
||||
exchanged_auth_credential=exchanged_credential,
|
||||
)
|
||||
|
||||
def get_credential_key(self) -> str:
|
||||
"""Generates a unique key for the given auth scheme and credential."""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.raw_auth_credential
|
||||
if auth_scheme.model_extra:
|
||||
auth_scheme = auth_scheme.model_copy(deep=True)
|
||||
auth_scheme.model_extra.clear()
|
||||
scheme_name = (
|
||||
f"{auth_scheme.type_.name}_{hash(auth_scheme.model_dump_json())}"
|
||||
if auth_scheme
|
||||
else ""
|
||||
)
|
||||
if auth_credential.model_extra:
|
||||
auth_credential = auth_credential.model_copy(deep=True)
|
||||
auth_credential.model_extra.clear()
|
||||
credential_name = (
|
||||
f"{auth_credential.auth_type.value}_{hash(auth_credential.model_dump_json())}"
|
||||
if auth_credential
|
||||
else ""
|
||||
)
|
||||
|
||||
return f"temp:adk_{scheme_name}_{credential_name}"
|
||||
|
||||
def generate_auth_uri(
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
"""Generates an response containing the auth uri for user to sign in.
|
||||
|
||||
Returns:
|
||||
An AuthCredential object containing the auth URI and state.
|
||||
|
||||
Raises:
|
||||
ValueError: If the authorization endpoint is not configured in the auth
|
||||
scheme.
|
||||
"""
|
||||
auth_scheme = self.auth_config.auth_scheme
|
||||
auth_credential = self.auth_config.raw_auth_credential
|
||||
|
||||
if isinstance(auth_scheme, OpenIdConnectWithConfig):
|
||||
authorization_endpoint = auth_scheme.authorization_endpoint
|
||||
scopes = auth_scheme.scopes
|
||||
else:
|
||||
authorization_endpoint = (
|
||||
auth_scheme.flows.implicit
|
||||
and auth_scheme.flows.implicit.authorizationUrl
|
||||
or auth_scheme.flows.authorizationCode
|
||||
and auth_scheme.flows.authorizationCode.authorizationUrl
|
||||
or auth_scheme.flows.clientCredentials
|
||||
and auth_scheme.flows.clientCredentials.tokenUrl
|
||||
or auth_scheme.flows.password
|
||||
and auth_scheme.flows.password.tokenUrl
|
||||
)
|
||||
scopes = (
|
||||
auth_scheme.flows.implicit
|
||||
and auth_scheme.flows.implicit.scopes
|
||||
or auth_scheme.flows.authorizationCode
|
||||
and auth_scheme.flows.authorizationCode.scopes
|
||||
or auth_scheme.flows.clientCredentials
|
||||
and auth_scheme.flows.clientCredentials.scopes
|
||||
or auth_scheme.flows.password
|
||||
and auth_scheme.flows.password.scopes
|
||||
)
|
||||
scopes = list(scopes.keys())
|
||||
|
||||
client = OAuth2Session(
|
||||
auth_credential.oauth2.client_id,
|
||||
auth_credential.oauth2.client_secret,
|
||||
scope=" ".join(scopes),
|
||||
redirect_uri=auth_credential.oauth2.redirect_uri,
|
||||
)
|
||||
uri, state = client.create_authorization_url(
|
||||
url=authorization_endpoint, access_type="offline", prompt="consent"
|
||||
)
|
||||
exchanged_auth_credential = auth_credential.model_copy(deep=True)
|
||||
exchanged_auth_credential.oauth2.auth_uri = uri
|
||||
exchanged_auth_credential.oauth2.state = state
|
||||
|
||||
return exchanged_auth_credential
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from ..flows.llm_flows import functions
|
||||
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
|
||||
from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
from ..models.llm_request import LlmRequest
|
||||
from .auth_handler import AuthHandler
|
||||
from .auth_tool import AuthConfig
|
||||
from .auth_tool import AuthToolArguments
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
|
||||
|
||||
class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Handles auth information to build the LLM request."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
events = invocation_context.session.events
|
||||
if not events:
|
||||
return
|
||||
|
||||
request_euc_function_call_ids = set()
|
||||
for k in range(len(events) - 1, -1, -1):
|
||||
event = events[k]
|
||||
# look for first event authored by user
|
||||
if not event.author or event.author != 'user':
|
||||
continue
|
||||
responses = event.get_function_responses()
|
||||
if not responses:
|
||||
return
|
||||
|
||||
for function_call_response in responses:
|
||||
if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
|
||||
continue
|
||||
# found the function call response for the system long running request euc
|
||||
# function call
|
||||
request_euc_function_call_ids.add(function_call_response.id)
|
||||
auth_config = AuthConfig.model_validate(function_call_response.response)
|
||||
AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
|
||||
state=invocation_context.session.state
|
||||
)
|
||||
break
|
||||
|
||||
if not request_euc_function_call_ids:
|
||||
return
|
||||
|
||||
for i in range(len(events) - 2, -1, -1):
|
||||
event = events[i]
|
||||
# looking for the system long running request euc function call
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
|
||||
tools_to_resume = set()
|
||||
|
||||
for function_call in function_calls:
|
||||
if function_call.id not in request_euc_function_call_ids:
|
||||
continue
|
||||
args = AuthToolArguments.model_validate(function_call.args)
|
||||
|
||||
tools_to_resume.add(args.function_call_id)
|
||||
if not tools_to_resume:
|
||||
continue
|
||||
|
||||
# found the the system long running request euc function call
|
||||
# looking for original function call that requests euc
|
||||
for j in range(i - 1, -1, -1):
|
||||
event = events[j]
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
for function_call in function_calls:
|
||||
function_response_event = None
|
||||
if function_call.id in tools_to_resume:
|
||||
function_response_event = await functions.handle_function_calls_async(
|
||||
invocation_context,
|
||||
event,
|
||||
{tool.name: tool for tool in agent.canonical_tools},
|
||||
# there could be parallel function calls that require auth
|
||||
# auth response would be a dict keyed by function call id
|
||||
tools_to_resume,
|
||||
)
|
||||
if function_response_event:
|
||||
yield function_response_event
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
request_processor = _AuthLlmRequestProcessor()
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from fastapi.openapi.models import SecurityBase
|
||||
from fastapi.openapi.models import SecurityScheme
|
||||
from fastapi.openapi.models import SecuritySchemeType
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class OpenIdConnectWithConfig(SecurityBase):
|
||||
type_: SecuritySchemeType = Field(
|
||||
default=SecuritySchemeType.openIdConnect, alias="type"
|
||||
)
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
userinfo_endpoint: Optional[str] = None
|
||||
revocation_endpoint: Optional[str] = None
|
||||
token_endpoint_auth_methods_supported: Optional[List[str]] = None
|
||||
grant_types_supported: Optional[List[str]] = None
|
||||
scopes: Optional[List[str]] = None
|
||||
|
||||
|
||||
# AuthSchemes contains SecuritySchemes from OpenAPI 3.0 and an extra flattened OpenIdConnectWithConfig.
|
||||
AuthScheme = Union[SecurityScheme, OpenIdConnectWithConfig]
|
||||
|
||||
|
||||
class OAuthGrantType(str, Enum):
|
||||
"""Represents the OAuth2 flow (or grant type)."""
|
||||
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
IMPLICIT = "implicit"
|
||||
PASSWORD = "password"
|
||||
|
||||
@staticmethod
|
||||
def from_flow(flow: OAuthFlows) -> "OAuthGrantType":
|
||||
"""Converts an OAuthFlows object to a OAuthGrantType."""
|
||||
if flow.clientCredentials:
|
||||
return OAuthGrantType.CLIENT_CREDENTIALS
|
||||
if flow.authorizationCode:
|
||||
return OAuthGrantType.AUTHORIZATION_CODE
|
||||
if flow.implicit:
|
||||
return OAuthGrantType.IMPLICIT
|
||||
if flow.password:
|
||||
return OAuthGrantType.PASSWORD
|
||||
return None
|
||||
|
||||
|
||||
# AuthSchemeType re-exports SecuritySchemeType from OpenAPI 3.0.
|
||||
AuthSchemeType = SecuritySchemeType
|
||||
@@ -0,0 +1,55 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_schemes import AuthScheme
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""The auth config sent by tool asking client to collect auth credentials and
|
||||
|
||||
adk and client will help to fill in the response
|
||||
"""
|
||||
|
||||
auth_scheme: AuthScheme
|
||||
"""The auth scheme used to collect credentials"""
|
||||
raw_auth_credential: AuthCredential = None
|
||||
"""The raw auth credential used to collect credentials. The raw auth
|
||||
credentials are used in some auth scheme that needs to exchange auth
|
||||
credentials. e.g. OAuth2 and OIDC. For other auth scheme, it could be None.
|
||||
"""
|
||||
exchanged_auth_credential: AuthCredential = None
|
||||
"""The exchanged auth credential used to collect credentials. adk and client
|
||||
will work together to fill it. For those auth scheme that doesn't need to
|
||||
exchange auth credentials, e.g. API key, service account etc. It's filled by
|
||||
client directly. For those auth scheme that need to exchange auth credentials,
|
||||
e.g. OAuth2 and OIDC, it's first filled by adk. If the raw credentials
|
||||
passed by tool only has client id and client credential, adk will help to
|
||||
generate the corresponding authorization uri and state and store the processed
|
||||
credential in this field. If the raw credentials passed by tool already has
|
||||
authorization uri, state, etc. then it's copied to this field. Client will use
|
||||
this field to guide the user through the OAuth2 flow and fill auth response in
|
||||
this field"""
|
||||
|
||||
|
||||
class AuthToolArguments(BaseModel):
|
||||
"""the arguments for the special long running function tool that is used to
|
||||
|
||||
request end user credentials.
|
||||
"""
|
||||
|
||||
function_call_id: str
|
||||
auth_config: AuthConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cli_tools_click import main
|
||||
@@ -0,0 +1,18 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cli_tools_click import main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
148
.venv/lib/python3.10/site-packages/google/adk/cli/agent_graph.py
Normal file
148
.venv/lib/python3.10/site-packages/google/adk/cli/agent_graph.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import graphviz
|
||||
|
||||
from ..agents import BaseAgent
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..tools.agent_tool import AgentTool
|
||||
from ..tools.base_tool import BaseTool
|
||||
from ..tools.function_tool import FunctionTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool
|
||||
except ModuleNotFoundError:
|
||||
retrieval_tool_module_loaded = False
|
||||
else:
|
||||
retrieval_tool_module_loaded = True
|
||||
|
||||
|
||||
def build_graph(graph, agent: BaseAgent, highlight_pairs):
|
||||
dark_green = '#0F5223'
|
||||
light_green = '#69CB87'
|
||||
light_gray = '#cccccc'
|
||||
|
||||
def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
if isinstance(tool_or_agent, BaseAgent):
|
||||
return tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, BaseTool):
|
||||
return tool_or_agent.name
|
||||
else:
|
||||
raise ValueError(f'Unsupported tool type: {tool_or_agent}')
|
||||
|
||||
def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
|
||||
if isinstance(tool_or_agent, BaseAgent):
|
||||
return '🤖 ' + tool_or_agent.name
|
||||
elif retrieval_tool_module_loaded and isinstance(
|
||||
tool_or_agent, BaseRetrievalTool
|
||||
):
|
||||
return '🔎 ' + tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, FunctionTool):
|
||||
return '🔧 ' + tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, AgentTool):
|
||||
return '🤖 ' + tool_or_agent.name
|
||||
elif isinstance(tool_or_agent, BaseTool):
|
||||
return '🔧 ' + tool_or_agent.name
|
||||
else:
|
||||
logger.warning(
|
||||
'Unsupported tool, type: %s, obj: %s',
|
||||
type(tool_or_agent),
|
||||
tool_or_agent,
|
||||
)
|
||||
return f'❓ Unsupported tool type: {type(tool_or_agent)}'
|
||||
|
||||
def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
if isinstance(tool_or_agent, BaseAgent):
|
||||
return 'ellipse'
|
||||
elif retrieval_tool_module_loaded and isinstance(
|
||||
tool_or_agent, BaseRetrievalTool
|
||||
):
|
||||
return 'cylinder'
|
||||
elif isinstance(tool_or_agent, FunctionTool):
|
||||
return 'box'
|
||||
elif isinstance(tool_or_agent, BaseTool):
|
||||
return 'box'
|
||||
else:
|
||||
logger.warning(
|
||||
'Unsupported tool, type: %s, obj: %s',
|
||||
type(tool_or_agent),
|
||||
tool_or_agent,
|
||||
)
|
||||
return 'cylinder'
|
||||
|
||||
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
|
||||
name = get_node_name(tool_or_agent)
|
||||
shape = get_node_shape(tool_or_agent)
|
||||
caption = get_node_caption(tool_or_agent)
|
||||
if highlight_pairs:
|
||||
for highlight_tuple in highlight_pairs:
|
||||
if name in highlight_tuple:
|
||||
graph.node(
|
||||
name,
|
||||
caption,
|
||||
style='filled,rounded',
|
||||
fillcolor=dark_green,
|
||||
color=dark_green,
|
||||
shape=shape,
|
||||
fontcolor=light_gray,
|
||||
)
|
||||
return
|
||||
# if not in highlight, draw non-highliht node
|
||||
graph.node(
|
||||
name,
|
||||
caption,
|
||||
shape=shape,
|
||||
style='rounded',
|
||||
color=light_gray,
|
||||
fontcolor=light_gray,
|
||||
)
|
||||
|
||||
def draw_edge(from_name, to_name):
|
||||
if highlight_pairs:
|
||||
for highlight_from, highlight_to in highlight_pairs:
|
||||
if from_name == highlight_from and to_name == highlight_to:
|
||||
graph.edge(from_name, to_name, color=light_green)
|
||||
return
|
||||
elif from_name == highlight_to and to_name == highlight_from:
|
||||
graph.edge(from_name, to_name, color=light_green, dir='back')
|
||||
return
|
||||
# if no need to highlight, color gray
|
||||
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
|
||||
|
||||
draw_node(agent)
|
||||
for sub_agent in agent.sub_agents:
|
||||
build_graph(graph, sub_agent, highlight_pairs)
|
||||
draw_edge(agent.name, sub_agent.name)
|
||||
if isinstance(agent, LlmAgent):
|
||||
for tool in agent.canonical_tools:
|
||||
draw_node(tool)
|
||||
draw_edge(agent.name, get_node_name(tool))
|
||||
|
||||
|
||||
def get_agent_graph(root_agent, highlights_pairs, image=False):
|
||||
print('build graph')
|
||||
graph = graphviz.Digraph(graph_attr={'rankdir': 'LR', 'bgcolor': '#333537'})
|
||||
build_graph(graph, root_agent, highlights_pairs)
|
||||
if image:
|
||||
return graph.pipe(format='png')
|
||||
else:
|
||||
return graph
|
||||
@@ -0,0 +1,17 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_756_3354)">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M8.69139 10.1458C8.89799 10.3937 8.8645 10.7622 8.61657 10.9688L7.07351 12.2547L8.61657 13.5406C8.8645 13.7472 8.89799 14.1157 8.69139 14.3636C8.48478 14.6115 8.11631 14.645 7.86838 14.4384L5.82029 12.7317C5.52243 12.4834 5.52242 12.026 5.82029 11.7777L7.86838 10.071C8.11631 9.86438 8.48478 9.89788 8.69139 10.1458Z" fill="#EA4335"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.4459 10.1458C11.2393 10.3937 11.2728 10.7622 11.5207 10.9688L13.0638 12.2547L11.5207 13.5406C11.2728 13.7472 11.2393 14.1157 11.4459 14.3636C11.6525 14.6115 12.021 14.645 12.2689 14.4384L14.317 12.7317C14.6149 12.4834 14.6149 12.026 14.317 11.7777L12.2689 10.071C12.021 9.86438 11.6525 9.89788 11.4459 10.1458Z" fill="#EA4335"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M5.94165 2.19288C4.44903 2.19288 3.23902 3.40289 3.23902 4.89551C3.23902 6.38813 4.44903 7.59814 5.94165 7.59814H8.60776V8.76685H5.94165C3.80357 8.76685 2.07031 7.03359 2.07031 4.89551C2.07031 2.75743 3.80357 1.02417 5.94165 1.02417H9.73995C10.0627 1.02417 10.3243 1.28579 10.3243 1.60852C10.3243 1.93125 10.0627 2.19288 9.73995 2.19288H5.94165Z" fill="#4285F4"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M10.6895 2.19288C12.1821 2.19288 13.3922 3.40289 13.3922 4.89551C13.3922 6.38813 12.1821 7.59814 10.6895 7.59814H6.89123C6.5685 7.59814 6.30687 7.85977 6.30687 8.1825C6.30687 8.50523 6.5685 8.76685 6.89123 8.76685H10.6895C12.8276 8.76685 14.5609 7.03359 14.5609 4.89551C14.5609 2.75743 12.8276 1.02417 10.6895 1.02417H6.89123C6.5685 1.02417 6.30687 1.28579 6.30687 1.60852C6.30687 1.93125 6.5685 2.19288 6.89123 2.19288H10.6895Z" fill="#34A853"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.23902 10.739H4.18859C4.51132 10.739 4.77295 10.4774 4.77295 10.1547C4.77295 9.83196 4.51132 9.57033 4.18859 9.57033H3.01989C2.49545 9.57033 2.07031 9.99547 2.07031 10.5199V14.026C2.07031 14.5505 2.49545 14.9756 3.01989 14.9756H4.18859C4.51132 14.9756 4.77295 14.714 4.77295 14.3912C4.77295 14.0685 4.51132 13.8069 4.18859 13.8069H3.23902V10.739Z" fill="#FBBC04"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M10.9452 8.1825C10.9452 7.85977 10.6836 7.59814 10.3608 7.59814H6.89123C6.5685 7.59814 6.30687 7.85977 6.30687 8.1825C6.30687 8.50523 6.5685 8.76685 6.89123 8.76685H10.3608C10.6836 8.76685 10.9452 8.50523 10.9452 8.1825Z" fill="#4285F4"/>
|
||||
<path d="M6.74514 4.89551C6.74514 5.25858 6.45081 5.55291 6.08774 5.55291C5.72467 5.55291 5.43034 5.25858 5.43034 4.89551C5.43034 4.53244 5.72467 4.23811 6.08774 4.23811C6.45081 4.23811 6.74514 4.53244 6.74514 4.89551Z" fill="#4285F4"/>
|
||||
<path d="M11.2739 4.89551C11.2739 5.25858 10.9795 5.55291 10.6165 5.55291C10.2534 5.55291 9.95908 5.25858 9.95908 4.89551C9.95908 4.53244 10.2534 4.23811 10.6165 4.23811C10.9795 4.23811 11.2739 4.53244 11.2739 4.89551Z" fill="#4285F4"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_756_3354">
|
||||
<rect width="12.6294" height="14" fill="white" transform="translate(2 1)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.1 KiB |
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
class AudioProcessor extends AudioWorkletProcessor {
|
||||
constructor() {
|
||||
super();
|
||||
this.targetSampleRate = 22000; // Change to your desired rate
|
||||
this.originalSampleRate = sampleRate; // Browser's sample rate
|
||||
this.resampleRatio = this.originalSampleRate / this.targetSampleRate;
|
||||
}
|
||||
|
||||
process(inputs, outputs, parameters) {
|
||||
const input = inputs[0];
|
||||
if (input.length > 0) {
|
||||
let audioData = input[0]; // Get first channel's data
|
||||
|
||||
if (this.resampleRatio !== 1) {
|
||||
audioData = this.resample(audioData);
|
||||
}
|
||||
|
||||
this.port.postMessage(audioData);
|
||||
}
|
||||
return true; // Keep processor alive
|
||||
}
|
||||
|
||||
resample(audioData) {
|
||||
const newLength = Math.round(audioData.length / this.resampleRatio);
|
||||
const resampled = new Float32Array(newLength);
|
||||
|
||||
for (let i = 0; i < newLength; i++) {
|
||||
const srcIndex = Math.floor(i * this.resampleRatio);
|
||||
resampled[i] = audioData[srcIndex]; // Nearest neighbor resampling
|
||||
}
|
||||
return resampled;
|
||||
}
|
||||
}
|
||||
|
||||
registerProcessor('audio-processor', AudioProcessor);
|
||||
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"backendUrl": ""
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Copyright 2025 Google LLC
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
html{color-scheme:dark}html{--mat-sys-background: light-dark(#fcf9f8, #131314);--mat-sys-error: light-dark(#ba1a1a, #ffb4ab);--mat-sys-error-container: light-dark(#ffdad6, #93000a);--mat-sys-inverse-on-surface: light-dark(#f3f0f0, #313030);--mat-sys-inverse-primary: light-dark(#c1c7cd, #595f65);--mat-sys-inverse-surface: light-dark(#313030, #e5e2e2);--mat-sys-on-background: light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-error: light-dark(#ffffff, #690005);--mat-sys-on-error-container: light-dark(#410002, #ffdad6);--mat-sys-on-primary: light-dark(#ffffff, #2b3136);--mat-sys-on-primary-container: light-dark(#161c21, #dde3e9);--mat-sys-on-primary-fixed: light-dark(#161c21, #161c21);--mat-sys-on-primary-fixed-variant: light-dark(#41474d, #41474d);--mat-sys-on-secondary: light-dark(#ffffff, #003061);--mat-sys-on-secondary-container: light-dark(#001b3c, #d5e3ff);--mat-sys-on-secondary-fixed: light-dark(#001b3c, #001b3c);--mat-sys-on-secondary-fixed-variant: light-dark(#0f4784, #0f4784);--mat-sys-on-surface: light-dark(#1c1b1c, #e5e2e2);--mat-sys-on-surface-variant: light-dark(#44474a, #e1e2e6);--mat-sys-on-tertiary: light-dark(#ffffff, #2b3136);--mat-sys-on-tertiary-container: light-dark(#161c21, #dde3e9);--mat-sys-on-tertiary-fixed: light-dark(#161c21, #161c21);--mat-sys-on-tertiary-fixed-variant: light-dark(#41474d, #41474d);--mat-sys-outline: light-dark(#74777b, #8e9194);--mat-sys-outline-variant: light-dark(#c4c7ca, #44474a);--mat-sys-primary: light-dark(#595f65, #c1c7cd);--mat-sys-primary-container: light-dark(#dde3e9, #41474d);--mat-sys-primary-fixed: light-dark(#dde3e9, #dde3e9);--mat-sys-primary-fixed-dim: light-dark(#c1c7cd, #c1c7cd);--mat-sys-scrim: light-dark(#000000, #000000);--mat-sys-secondary: light-dark(#305f9d, #a7c8ff);--mat-sys-secondary-container: light-dark(#d5e3ff, #0f4784);--mat-sys-secondary-fixed: light-dark(#d5e3ff, #d5e3ff);--mat-sys-secondary-fixed-dim: light-dark(#a7c8ff, #a7c8ff);--mat-sys-shadow: light-dark(#000000, #000000);--mat-sys-surface: light-dark(#fcf9f8, #131314);--mat-sys-surface-bright: light-dark(#fcf9f8, #393939);--mat-sys-surface-container: light-dark(#f0eded, #201f20);--mat-sys-surface-container-high: light-dark(#eae7e7, #2a2a2a);--mat-sys-surface-container-highest: light-dark(#e5e2e2, #393939);--mat-sys-surface-container-low: light-dark(#f6f3f3, #1c1b1c);--mat-sys-surface-container-lowest: light-dark(#ffffff, #0e0e0e);--mat-sys-surface-dim: light-dark(#dcd9d9, #131314);--mat-sys-surface-tint: light-dark(#595f65, #c1c7cd);--mat-sys-surface-variant: light-dark(#e1e2e6, #44474a);--mat-sys-tertiary: light-dark(#595f65, #c1c7cd);--mat-sys-tertiary-container: light-dark(#dde3e9, #41474d);--mat-sys-tertiary-fixed: light-dark(#dde3e9, #dde3e9);--mat-sys-tertiary-fixed-dim: light-dark(#c1c7cd, #c1c7cd);--mat-sys-neutral-variant20: #2d3134;--mat-sys-neutral10: #1c1b1c}html{--mat-sys-level0: 0px 0px 0px 0px rgba(0, 0, 0, .2), 0px 0px 0px 0px rgba(0, 0, 0, .14), 0px 0px 0px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level1: 0px 2px 1px -1px rgba(0, 0, 0, .2), 0px 1px 1px 0px rgba(0, 0, 0, .14), 0px 1px 3px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level2: 0px 3px 3px -2px rgba(0, 0, 0, .2), 0px 3px 4px 0px rgba(0, 0, 0, .14), 0px 1px 8px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level3: 0px 3px 5px -1px rgba(0, 0, 0, .2), 0px 6px 10px 0px rgba(0, 0, 0, .14), 0px 1px 18px 0px rgba(0, 0, 0, .12)}html{--mat-sys-level4: 0px 5px 5px -3px rgba(0, 0, 0, .2), 0px 8px 10px 1px rgba(0, 0, 0, .14), 0px 3px 14px 2px rgba(0, 0, 0, .12)}html{--mat-sys-level5: 0px 7px 8px -4px rgba(0, 0, 0, .2), 0px 12px 17px 2px rgba(0, 0, 0, .14), 0px 5px 22px 4px rgba(0, 0, 0, .12)}html{--mat-sys-corner-extra-large: 28px;--mat-sys-corner-extra-large-top: 28px 28px 0 0;--mat-sys-corner-extra-small: 4px;--mat-sys-corner-extra-small-top: 4px 4px 0 0;--mat-sys-corner-full: 9999px;--mat-sys-corner-large: 16px;--mat-sys-corner-large-end: 0 16px 16px 0;--mat-sys-corner-large-start: 16px 0 0 16px;--mat-sys-corner-large-top: 16px 16px 0 0;--mat-sys-corner-medium: 12px;--mat-sys-corner-none: 0;--mat-sys-corner-small: 8px}html{--mat-sys-dragged-state-layer-opacity: .16;--mat-sys-focus-state-layer-opacity: .12;--mat-sys-hover-state-layer-opacity: .08;--mat-sys-pressed-state-layer-opacity: .12}html{font-family:Google Sans,Helvetica Neue,sans-serif!important}body{height:100vh;margin:0}markdown p{margin-block-start:.5em;margin-block-end:.5em}:root{--mat-sys-primary: black;--mdc-checkbox-selected-icon-color: white;--mat-sys-background: #131314;--mat-tab-header-active-label-text-color: #8AB4F8;--mat-tab-header-active-hover-label-text-color: #8AB4F8;--mat-tab-header-active-focus-label-text-color: #8AB4F8;--mat-tab-header-label-text-weight: 500;--mdc-text-button-label-text-color: #89b4f8}:root{--mdc-dialog-container-color: #2b2b2f}:root{--mdc-dialog-subhead-color: white}:root{--mdc-circular-progress-active-indicator-color: #a8c7fa}:root{--mdc-circular-progress-size: 80}
|
||||
183
.venv/lib/python3.10/site-packages/google/adk/cli/cli.py
Normal file
183
.venv/lib/python3.10/site-packages/google/adk/cli/cli.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from datetime import datetime
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..artifacts import BaseArtifactService
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..runners import Runner
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from .utils import envs
|
||||
|
||||
|
||||
class InputFile(BaseModel):
|
||||
state: dict[str, object]
|
||||
queries: list[str]
|
||||
|
||||
|
||||
async def run_input_file(
|
||||
app_name: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
input_path: str,
|
||||
) -> None:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
input_file = InputFile.model_validate_json(f.read())
|
||||
input_file.state['_time'] = datetime.now()
|
||||
|
||||
session.state = input_file.state
|
||||
for query in input_file.queries:
|
||||
click.echo(f'user: {query}')
|
||||
content = types.Content(role='user', parts=[types.Part(text=query)])
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
):
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
|
||||
|
||||
async def run_interactively(
|
||||
app_name: str,
|
||||
root_agent: LlmAgent,
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
) -> None:
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
)
|
||||
while True:
|
||||
query = input('user: ')
|
||||
if not query or not query.strip():
|
||||
continue
|
||||
if query == 'exit':
|
||||
break
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
new_message=types.Content(role='user', parts=[types.Part(text=query)]),
|
||||
):
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
|
||||
|
||||
async def run_cli(
|
||||
*,
|
||||
agent_parent_dir: str,
|
||||
agent_folder_name: str,
|
||||
json_file_path: Optional[str] = None,
|
||||
save_session: bool,
|
||||
) -> None:
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
|
||||
Args:
|
||||
agent_parent_dir: str, the absolute path of the parent folder of the agent
|
||||
folder.
|
||||
agent_folder_name: str, the name of the agent folder.
|
||||
json_file_path: Optional[str], the absolute path to the json file, either
|
||||
*.input.json or *.session.json.
|
||||
save_session: bool, whether to save the session on exit.
|
||||
"""
|
||||
if agent_parent_dir not in sys.path:
|
||||
sys.path.append(agent_parent_dir)
|
||||
|
||||
artifact_service = InMemoryArtifactService()
|
||||
session_service = InMemorySessionService()
|
||||
session = session_service.create_session(
|
||||
app_name=agent_folder_name, user_id='test_user'
|
||||
)
|
||||
|
||||
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
|
||||
agent_module = importlib.import_module(agent_folder_name)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
|
||||
if json_file_path:
|
||||
if json_file_path.endswith('.input.json'):
|
||||
await run_input_file(
|
||||
app_name=agent_folder_name,
|
||||
root_agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
input_path=json_file_path,
|
||||
)
|
||||
elif json_file_path.endswith('.session.json'):
|
||||
with open(json_file_path, 'r') as f:
|
||||
session = Session.model_validate_json(f.read())
|
||||
for content in session.get_contents():
|
||||
if content.role == 'user':
|
||||
print('user: ', content.parts[0].text)
|
||||
else:
|
||||
print(content.parts[0].text)
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
session_service,
|
||||
)
|
||||
else:
|
||||
print(f'Unsupported file type: {json_file_path}')
|
||||
exit(1)
|
||||
else:
|
||||
print(f'Running agent {root_agent.name}, type exit to exit.')
|
||||
await run_interactively(
|
||||
agent_folder_name,
|
||||
root_agent,
|
||||
artifact_service,
|
||||
session,
|
||||
session_service,
|
||||
)
|
||||
|
||||
if save_session:
|
||||
if json_file_path:
|
||||
session_path = json_file_path.replace('.input.json', '.session.json')
|
||||
else:
|
||||
session_id = input('Session ID to save: ')
|
||||
session_path = f'{agent_module_path}/{session_id}.session.json'
|
||||
|
||||
# Fetch the session again to get all the details.
|
||||
session = session_service.get_session(
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
)
|
||||
with open(session_path, 'w') as f:
|
||||
f.write(session.model_dump_json(indent=2, exclude_none=True))
|
||||
|
||||
print('Session saved to', session_path)
|
||||
279
.venv/lib/python3.10/site-packages/google/adk/cli/cli_create.py
Normal file
279
.venv/lib/python3.10/site-packages/google/adk/cli/cli_create.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import click
|
||||
|
||||
_INIT_PY_TEMPLATE = """\
|
||||
from . import agent
|
||||
"""
|
||||
|
||||
_AGENT_PY_TEMPLATE = """\
|
||||
from google.adk.agents import Agent
|
||||
|
||||
root_agent = Agent(
|
||||
model='{model_name}',
|
||||
name='root_agent',
|
||||
description='A helpful assistant for user questions.',
|
||||
instruction='Answer user questions to the best of your knowledge',
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
_GOOGLE_API_MSG = """
|
||||
Don't have API Key? Create one in AI Studio: https://aistudio.google.com/apikey
|
||||
"""
|
||||
|
||||
_GOOGLE_CLOUD_SETUP_MSG = """
|
||||
You need an existing Google Cloud account and project, check out this link for details:
|
||||
https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai
|
||||
"""
|
||||
|
||||
_OTHER_MODEL_MSG = """
|
||||
Please see below guide to configure other models:
|
||||
https://google.github.io/adk-docs/agents/models
|
||||
"""
|
||||
|
||||
_SUCCESS_MSG = """
|
||||
Agent created in {agent_folder}:
|
||||
- .env
|
||||
- __init__.py
|
||||
- agent.py
|
||||
"""
|
||||
|
||||
|
||||
def _get_gcp_project_from_gcloud() -> str:
|
||||
"""Uses gcloud to get default project."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gcloud", "config", "get-value", "project"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return ""
|
||||
|
||||
|
||||
def _get_gcp_region_from_gcloud() -> str:
|
||||
"""Uses gcloud to get default region."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["gcloud", "config", "get-value", "compute/region"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return ""
|
||||
|
||||
|
||||
def _prompt_str(
|
||||
prompt_prefix: str,
|
||||
*,
|
||||
prior_msg: Optional[str] = None,
|
||||
default_value: Optional[str] = None,
|
||||
) -> str:
|
||||
if prior_msg:
|
||||
click.secho(prior_msg, fg="green")
|
||||
while True:
|
||||
value: str = click.prompt(
|
||||
prompt_prefix, default=default_value or None, type=str
|
||||
)
|
||||
if value and value.strip():
|
||||
return value.strip()
|
||||
|
||||
|
||||
def _prompt_for_google_cloud(
|
||||
google_cloud_project: Optional[str],
|
||||
) -> str:
|
||||
"""Prompts user for Google Cloud project ID."""
|
||||
google_cloud_project = (
|
||||
google_cloud_project
|
||||
or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
|
||||
or _get_gcp_project_from_gcloud()
|
||||
)
|
||||
|
||||
google_cloud_project = _prompt_str(
|
||||
"Enter Google Cloud project ID", default_value=google_cloud_project
|
||||
)
|
||||
|
||||
return google_cloud_project
|
||||
|
||||
|
||||
def _prompt_for_google_cloud_region(
|
||||
google_cloud_region: Optional[str],
|
||||
) -> str:
|
||||
"""Prompts user for Google Cloud region."""
|
||||
google_cloud_region = (
|
||||
google_cloud_region
|
||||
or os.environ.get("GOOGLE_CLOUD_LOCATION", None)
|
||||
or _get_gcp_region_from_gcloud()
|
||||
)
|
||||
|
||||
google_cloud_region = _prompt_str(
|
||||
"Enter Google Cloud region",
|
||||
default_value=google_cloud_region or "us-central1",
|
||||
)
|
||||
return google_cloud_region
|
||||
|
||||
|
||||
def _prompt_for_google_api_key(
|
||||
google_api_key: Optional[str],
|
||||
) -> str:
|
||||
"""Prompts user for Google API key."""
|
||||
google_api_key = google_api_key or os.environ.get("GOOGLE_API_KEY", None)
|
||||
|
||||
google_api_key = _prompt_str(
|
||||
"Enter Google API key",
|
||||
prior_msg=_GOOGLE_API_MSG,
|
||||
default_value=google_api_key,
|
||||
)
|
||||
return google_api_key
|
||||
|
||||
|
||||
def _generate_files(
|
||||
agent_folder: str,
|
||||
*,
|
||||
google_api_key: Optional[str] = None,
|
||||
google_cloud_project: Optional[str] = None,
|
||||
google_cloud_region: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
):
|
||||
"""Generates a folder name for the agent."""
|
||||
os.makedirs(agent_folder, exist_ok=True)
|
||||
|
||||
dotenv_file_path = os.path.join(agent_folder, ".env")
|
||||
init_file_path = os.path.join(agent_folder, "__init__.py")
|
||||
agent_file_path = os.path.join(agent_folder, "agent.py")
|
||||
|
||||
with open(dotenv_file_path, "w", encoding="utf-8") as f:
|
||||
lines = []
|
||||
if google_api_key:
|
||||
lines.append("GOOGLE_GENAI_USE_VERTEXAI=0")
|
||||
elif google_cloud_project and google_cloud_region:
|
||||
lines.append("GOOGLE_GENAI_USE_VERTEXAI=1")
|
||||
if google_api_key:
|
||||
lines.append(f"GOOGLE_API_KEY={google_api_key}")
|
||||
if google_cloud_project:
|
||||
lines.append(f"GOOGLE_CLOUD_PROJECT={google_cloud_project}")
|
||||
if google_cloud_region:
|
||||
lines.append(f"GOOGLE_CLOUD_LOCATION={google_cloud_region}")
|
||||
f.write("\n".join(lines))
|
||||
|
||||
with open(init_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(_INIT_PY_TEMPLATE)
|
||||
|
||||
with open(agent_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(_AGENT_PY_TEMPLATE.format(model_name=model))
|
||||
|
||||
click.secho(
|
||||
_SUCCESS_MSG.format(agent_folder=agent_folder),
|
||||
fg="green",
|
||||
)
|
||||
|
||||
|
||||
def _prompt_for_model() -> str:
|
||||
model_choice = click.prompt(
|
||||
"""\
|
||||
Choose a model for the root agent:
|
||||
1. gemini-2.0-flash-001
|
||||
2. Other models (fill later)
|
||||
Choose model""",
|
||||
type=click.Choice(["1", "2"]),
|
||||
)
|
||||
if model_choice == "1":
|
||||
return "gemini-2.0-flash-001"
|
||||
else:
|
||||
click.secho(_OTHER_MODEL_MSG, fg="green")
|
||||
return "<FILL_IN_MODEL>"
|
||||
|
||||
|
||||
def _prompt_to_choose_backend(
|
||||
google_api_key: Optional[str],
|
||||
google_cloud_project: Optional[str],
|
||||
google_cloud_region: Optional[str],
|
||||
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""Prompts user to choose backend.
|
||||
|
||||
Returns:
|
||||
A tuple of (google_api_key, google_cloud_project, google_cloud_region).
|
||||
"""
|
||||
backend_choice = click.prompt(
|
||||
"1. Google AI\n2. Vertex AI\nChoose a backend",
|
||||
type=click.Choice(["1", "2"]),
|
||||
)
|
||||
if backend_choice == "1":
|
||||
google_api_key = _prompt_for_google_api_key(google_api_key)
|
||||
elif backend_choice == "2":
|
||||
click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green")
|
||||
google_cloud_project = _prompt_for_google_cloud(google_cloud_project)
|
||||
google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region)
|
||||
return google_api_key, google_cloud_project, google_cloud_region
|
||||
|
||||
|
||||
def run_cmd(
|
||||
agent_name: str,
|
||||
*,
|
||||
model: Optional[str],
|
||||
google_api_key: Optional[str],
|
||||
google_cloud_project: Optional[str],
|
||||
google_cloud_region: Optional[str],
|
||||
):
|
||||
"""Runs `adk create` command to create agent template.
|
||||
|
||||
Args:
|
||||
agent_name: str, The name of the agent.
|
||||
google_api_key: Optional[str], The Google API key for using Google AI as
|
||||
backend.
|
||||
google_cloud_project: Optional[str], The Google Cloud project for using
|
||||
VertexAI as backend.
|
||||
google_cloud_region: Optional[str], The Google Cloud region for using
|
||||
VertexAI as backend.
|
||||
"""
|
||||
agent_folder = os.path.join(os.getcwd(), agent_name)
|
||||
# check folder doesn't exist or it's empty. Otherwise, throw
|
||||
if os.path.exists(agent_folder) and os.listdir(agent_folder):
|
||||
# Prompt user whether to override existing files using click
|
||||
if not click.confirm(
|
||||
f"Non-empty folder already exist: '{agent_folder}'\n"
|
||||
"Override existing content?",
|
||||
default=False,
|
||||
):
|
||||
raise click.Abort()
|
||||
|
||||
if not model:
|
||||
model = _prompt_for_model()
|
||||
|
||||
if not google_api_key and not (google_cloud_project and google_cloud_region):
|
||||
if model.startswith("gemini"):
|
||||
google_api_key, google_cloud_project, google_cloud_region = (
|
||||
_prompt_to_choose_backend(
|
||||
google_api_key, google_cloud_project, google_cloud_region
|
||||
)
|
||||
)
|
||||
|
||||
_generate_files(
|
||||
agent_folder,
|
||||
google_api_key=google_api_key,
|
||||
google_cloud_project=google_cloud_project,
|
||||
google_cloud_region=google_cloud_region,
|
||||
model=model,
|
||||
)
|
||||
188
.venv/lib/python3.10/site-packages/google/adk/cli/cli_deploy.py
Normal file
188
.venv/lib/python3.10/site-packages/google/adk/cli/cli_deploy.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
_DOCKERFILE_TEMPLATE = """
|
||||
FROM python:3.11-slim
|
||||
WORKDIR /app
|
||||
|
||||
# Create a non-root user
|
||||
RUN adduser --disabled-password --gecos "" myuser
|
||||
|
||||
# Change ownership of /app to myuser
|
||||
RUN chown -R myuser:myuser /app
|
||||
|
||||
# Switch to the non-root user
|
||||
USER myuser
|
||||
|
||||
# Set up environment variables - Start
|
||||
ENV PATH="/home/myuser/.local/bin:$PATH"
|
||||
|
||||
ENV GOOGLE_GENAI_USE_VERTEXAI=1
|
||||
ENV GOOGLE_CLOUD_PROJECT={gcp_project_id}
|
||||
ENV GOOGLE_CLOUD_LOCATION={gcp_region}
|
||||
|
||||
# Set up environment variables - End
|
||||
|
||||
# Install ADK - Start
|
||||
RUN pip install google-adk
|
||||
# Install ADK - End
|
||||
|
||||
# Copy agent - Start
|
||||
|
||||
COPY "agents/{app_name}/" "/app/agents/{app_name}/"
|
||||
{install_agent_deps}
|
||||
|
||||
# Copy agent - End
|
||||
|
||||
EXPOSE {port}
|
||||
|
||||
CMD adk {command} --port={port} {session_db_option} {trace_to_cloud_option} "/app/agents"
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_project(project_in_option: Optional[str]) -> str:
|
||||
if project_in_option:
|
||||
return project_in_option
|
||||
|
||||
result = subprocess.run(
|
||||
['gcloud', 'config', 'get-value', 'project'],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
project = result.stdout.strip()
|
||||
click.echo(f'Use default project: {project}')
|
||||
return project
|
||||
|
||||
|
||||
def to_cloud_run(
|
||||
*,
|
||||
agent_folder: str,
|
||||
project: Optional[str],
|
||||
region: Optional[str],
|
||||
service_name: str,
|
||||
app_name: str,
|
||||
temp_folder: str,
|
||||
port: int,
|
||||
trace_to_cloud: bool,
|
||||
with_ui: bool,
|
||||
verbosity: str,
|
||||
session_db_url: str,
|
||||
):
|
||||
"""Deploys an agent to Google Cloud Run.
|
||||
|
||||
`agent_folder` should contain the following files:
|
||||
|
||||
- __init__.py
|
||||
- agent.py
|
||||
- requirements.txt (optional, for additional dependencies)
|
||||
- ... (other required source files)
|
||||
|
||||
The folder structure of temp_folder will be
|
||||
|
||||
* dist/[google_adk wheel file]
|
||||
* agents/[app_name]/
|
||||
* agent source code from `agent_folder`
|
||||
|
||||
Args:
|
||||
agent_folder: The folder (absolute path) containing the agent source code.
|
||||
project: Google Cloud project id.
|
||||
region: Google Cloud region.
|
||||
service_name: The service name in Cloud Run.
|
||||
app_name: The name of the app, by default, it's basename of `agent_folder`.
|
||||
temp_folder: The temp folder for the generated Cloud Run source files.
|
||||
port: The port of the ADK api server.
|
||||
trace_to_cloud: Whether to enable Cloud Trace.
|
||||
with_ui: Whether to deploy with UI.
|
||||
verbosity: The verbosity level of the CLI.
|
||||
session_db_url: The database URL to connect the session.
|
||||
"""
|
||||
app_name = app_name or os.path.basename(agent_folder)
|
||||
|
||||
click.echo(f'Start generating Cloud Run source files in {temp_folder}')
|
||||
|
||||
# remove temp_folder if exists
|
||||
if os.path.exists(temp_folder):
|
||||
click.echo('Removing existing files')
|
||||
shutil.rmtree(temp_folder)
|
||||
|
||||
try:
|
||||
# copy agent source code
|
||||
click.echo('Copying agent source code...')
|
||||
agent_src_path = os.path.join(temp_folder, 'agents', app_name)
|
||||
shutil.copytree(agent_folder, agent_src_path)
|
||||
requirements_txt_path = os.path.join(agent_src_path, 'requirements.txt')
|
||||
install_agent_deps = (
|
||||
f'RUN pip install -r "/app/agents/{app_name}/requirements.txt"'
|
||||
if os.path.exists(requirements_txt_path)
|
||||
else ''
|
||||
)
|
||||
click.echo('Copying agent source code complete.')
|
||||
|
||||
# create Dockerfile
|
||||
click.echo('Creating Dockerfile...')
|
||||
dockerfile_content = _DOCKERFILE_TEMPLATE.format(
|
||||
gcp_project_id=project,
|
||||
gcp_region=region,
|
||||
app_name=app_name,
|
||||
port=port,
|
||||
command='web' if with_ui else 'api_server',
|
||||
install_agent_deps=install_agent_deps,
|
||||
session_db_option=f'--session_db_url={session_db_url}'
|
||||
if session_db_url
|
||||
else '',
|
||||
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
|
||||
)
|
||||
dockerfile_path = os.path.join(temp_folder, 'Dockerfile')
|
||||
os.makedirs(temp_folder, exist_ok=True)
|
||||
with open(dockerfile_path, 'w', encoding='utf-8') as f:
|
||||
f.write(
|
||||
dockerfile_content,
|
||||
)
|
||||
click.echo(f'Creating Dockerfile complete: {dockerfile_path}')
|
||||
|
||||
# Deploy to Cloud Run
|
||||
click.echo('Deploying to Cloud Run...')
|
||||
region_options = ['--region', region] if region else []
|
||||
project = _resolve_project(project)
|
||||
subprocess.run(
|
||||
[
|
||||
'gcloud',
|
||||
'run',
|
||||
'deploy',
|
||||
service_name,
|
||||
'--source',
|
||||
temp_folder,
|
||||
'--project',
|
||||
project,
|
||||
*region_options,
|
||||
'--port',
|
||||
str(port),
|
||||
'--verbosity',
|
||||
verbosity,
|
||||
'--labels',
|
||||
'created-by=adk',
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
finally:
|
||||
click.echo(f'Cleaning up the temp folder: {temp_folder}')
|
||||
shutil.rmtree(temp_folder)
|
||||
282
.venv/lib/python3.10/site-packages/google/adk/cli/cli_eval.py
Normal file
282
.venv/lib/python3.10/site-packages/google/adk/cli/cli_eval.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents import Agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvalStatus(Enum):
|
||||
PASSED = 1
|
||||
FAILED = 2
|
||||
NOT_EVALUATED = 3
|
||||
|
||||
|
||||
class EvalMetric(BaseModel):
|
||||
metric_name: str
|
||||
threshold: float
|
||||
|
||||
|
||||
class EvalMetricResult(BaseModel):
|
||||
score: Optional[float]
|
||||
eval_status: EvalStatus
|
||||
|
||||
|
||||
class EvalResult(BaseModel):
|
||||
eval_set_file: str
|
||||
eval_id: str
|
||||
final_eval_status: EvalStatus
|
||||
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
|
||||
session_id: str
|
||||
|
||||
|
||||
MISSING_EVAL_DEPENDENCIES_MESSAGE = (
|
||||
"Eval module is not installed, please install via `pip install"
|
||||
" google-adk[eval]`."
|
||||
)
|
||||
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"
|
||||
RESPONSE_MATCH_SCORE_KEY = "response_match_score"
|
||||
# This evaluation is not very stable.
|
||||
# This is always optional unless explicitly specified.
|
||||
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"
|
||||
|
||||
EVAL_SESSION_ID_PREFIX = "___eval___session___"
|
||||
DEFAULT_CRITERIA = {
|
||||
TOOL_TRAJECTORY_SCORE_KEY: 1.0, # 1-point scale; 1.0 is perfect.
|
||||
RESPONSE_MATCH_SCORE_KEY: 0.8,
|
||||
}
|
||||
|
||||
|
||||
def _import_from_path(module_name, file_path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def _get_agent_module(agent_module_file_path: str):
|
||||
file_path = os.path.join(agent_module_file_path, "__init__.py")
|
||||
module_name = "agent"
|
||||
return _import_from_path(module_name, file_path)
|
||||
|
||||
|
||||
def get_evaluation_criteria_or_default(
|
||||
eval_config_file_path: str,
|
||||
) -> dict[str, float]:
|
||||
"""Returns evaluation criteria from the config file, if present.
|
||||
|
||||
Otherwise a default one is returned.
|
||||
"""
|
||||
if eval_config_file_path:
|
||||
with open(eval_config_file_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
if "criteria" in config_data and isinstance(config_data["criteria"], dict):
|
||||
evaluation_criteria = config_data["criteria"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid format for test_config.json at {eval_config_file_path}."
|
||||
" Expected a 'criteria' dictionary."
|
||||
)
|
||||
else:
|
||||
logger.info("No config file supplied. Using default criteria.")
|
||||
evaluation_criteria = DEFAULT_CRITERIA
|
||||
|
||||
return evaluation_criteria
|
||||
|
||||
|
||||
def get_root_agent(agent_module_file_path: str) -> Agent:
|
||||
"""Returns root agent given the agent module."""
|
||||
agent_module = _get_agent_module(agent_module_file_path)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
return root_agent
|
||||
|
||||
|
||||
def try_get_reset_func(agent_module_file_path: str) -> Any:
|
||||
"""Returns reset function for the agent, if present, given the agent module."""
|
||||
agent_module = _get_agent_module(agent_module_file_path)
|
||||
reset_func = getattr(agent_module.agent, "reset_data", None)
|
||||
return reset_func
|
||||
|
||||
|
||||
def parse_and_get_evals_to_run(
|
||||
eval_set_file_path: tuple[str],
|
||||
) -> dict[str, list[str]]:
|
||||
"""Returns a dictionary of eval sets to evals that should be run."""
|
||||
eval_set_to_evals = {}
|
||||
for input_eval_set in eval_set_file_path:
|
||||
evals = []
|
||||
if ":" not in input_eval_set:
|
||||
eval_set_file = input_eval_set
|
||||
else:
|
||||
eval_set_file = input_eval_set.split(":")[0]
|
||||
evals = input_eval_set.split(":")[1].split(",")
|
||||
|
||||
if eval_set_file not in eval_set_to_evals:
|
||||
eval_set_to_evals[eval_set_file] = []
|
||||
|
||||
eval_set_to_evals[eval_set_file].extend(evals)
|
||||
|
||||
return eval_set_to_evals
|
||||
|
||||
|
||||
def run_evals(
|
||||
eval_set_to_evals: dict[str, list[str]],
|
||||
root_agent: Agent,
|
||||
reset_func: Optional[Any],
|
||||
eval_metrics: list[EvalMetric],
|
||||
session_service=None,
|
||||
artifact_service=None,
|
||||
print_detailed_results=False,
|
||||
) -> Generator[EvalResult, None, None]:
|
||||
try:
|
||||
from ..evaluation.agent_evaluator import EvaluationGenerator
|
||||
from ..evaluation.response_evaluator import ResponseEvaluator
|
||||
from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
|
||||
|
||||
"""Returns a summary of eval runs."""
|
||||
for eval_set_file, evals_to_run in eval_set_to_evals.items():
|
||||
with open(eval_set_file, "r", encoding="utf-8") as file:
|
||||
eval_items = json.load(file) # Load JSON into a list
|
||||
|
||||
assert eval_items, f"No eval data found in eval set file: {eval_set_file}"
|
||||
|
||||
for eval_item in eval_items:
|
||||
eval_name = eval_item["name"]
|
||||
eval_data = eval_item["data"]
|
||||
initial_session = eval_item.get("initial_session", {})
|
||||
|
||||
if evals_to_run and eval_name not in evals_to_run:
|
||||
continue
|
||||
|
||||
try:
|
||||
print(f"Running Eval: {eval_set_file}:{eval_name}")
|
||||
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"
|
||||
|
||||
scrape_result = EvaluationGenerator._process_query_with_root_agent(
|
||||
data=eval_data,
|
||||
root_agent=root_agent,
|
||||
reset_func=reset_func,
|
||||
initial_session=initial_session,
|
||||
session_id=session_id,
|
||||
session_service=session_service,
|
||||
artifact_service=artifact_service,
|
||||
)
|
||||
|
||||
eval_metric_results = []
|
||||
for eval_metric in eval_metrics:
|
||||
eval_metric_result = None
|
||||
if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
|
||||
score = TrajectoryEvaluator.evaluate(
|
||||
[scrape_result], print_detailed_results=print_detailed_results
|
||||
)
|
||||
eval_metric_result = _get_eval_metric_result(eval_metric, score)
|
||||
elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
|
||||
score = ResponseEvaluator.evaluate(
|
||||
[scrape_result],
|
||||
[RESPONSE_MATCH_SCORE_KEY],
|
||||
print_detailed_results=print_detailed_results,
|
||||
)
|
||||
eval_metric_result = _get_eval_metric_result(
|
||||
eval_metric, score["rouge_1/mean"].item()
|
||||
)
|
||||
elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
|
||||
score = ResponseEvaluator.evaluate(
|
||||
[scrape_result],
|
||||
[RESPONSE_EVALUATION_SCORE_KEY],
|
||||
print_detailed_results=print_detailed_results,
|
||||
)
|
||||
eval_metric_result = _get_eval_metric_result(
|
||||
eval_metric, score["coherence/mean"].item()
|
||||
)
|
||||
else:
|
||||
logger.warning("`%s` is not supported.", eval_metric.metric_name)
|
||||
eval_metric_results.append((
|
||||
eval_metric,
|
||||
EvalMetricResult(eval_status=EvalStatus.NOT_EVALUATED),
|
||||
))
|
||||
|
||||
eval_metric_results.append((
|
||||
eval_metric,
|
||||
eval_metric_result,
|
||||
))
|
||||
_print_eval_metric_result(eval_metric, eval_metric_result)
|
||||
|
||||
final_eval_status = EvalStatus.NOT_EVALUATED
|
||||
|
||||
# Go over the all the eval statuses and mark the final eval status as
|
||||
# passed if all of them pass, otherwise mark the final eval status to
|
||||
# failed.
|
||||
for eval_metric_result in eval_metric_results:
|
||||
eval_status = eval_metric_result[1].eval_status
|
||||
if eval_status == EvalStatus.PASSED:
|
||||
final_eval_status = EvalStatus.PASSED
|
||||
elif eval_status == EvalStatus.NOT_EVALUATED:
|
||||
continue
|
||||
elif eval_status == EvalStatus.FAILED:
|
||||
final_eval_status = EvalStatus.FAILED
|
||||
break
|
||||
else:
|
||||
raise ValueError("Unknown eval status.")
|
||||
|
||||
yield EvalResult(
|
||||
eval_set_file=eval_set_file,
|
||||
eval_id=eval_name,
|
||||
final_eval_status=final_eval_status,
|
||||
eval_metric_results=eval_metric_results,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if final_eval_status == EvalStatus.PASSED:
|
||||
result = "✅ Passed"
|
||||
else:
|
||||
result = "❌ Failed"
|
||||
|
||||
print(f"Result: {result}\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
logger.info("Error: %s", str(traceback.format_exc()))
|
||||
|
||||
|
||||
def _get_eval_metric_result(eval_metric, score):
|
||||
eval_status = (
|
||||
EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
|
||||
)
|
||||
return EvalMetricResult(score=score, eval_status=eval_status)
|
||||
|
||||
|
||||
def _print_eval_metric_result(eval_metric, eval_metric_result):
|
||||
print(
|
||||
f"Metric: {eval_metric.metric_name}\tStatus:"
|
||||
f" {eval_metric_result.eval_status}\tScore:"
|
||||
f" {eval_metric_result.score}\tThreshold: {eval_metric.threshold}"
|
||||
)
|
||||
@@ -0,0 +1,600 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
from . import cli_create
|
||||
from . import cli_deploy
|
||||
from .cli import run_cli
|
||||
from .cli_eval import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||
from .fast_api import get_fast_api_app
|
||||
from .utils import envs
|
||||
from .utils import logs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.group(context_settings={"max_content_width": 240})
|
||||
def main():
|
||||
"""Agent Development Kit CLI tools."""
|
||||
pass
|
||||
|
||||
|
||||
@main.group()
|
||||
def deploy():
|
||||
"""Deploys agent to hosted environments."""
|
||||
pass
|
||||
|
||||
|
||||
@main.command("create")
|
||||
@click.option(
|
||||
"--model",
|
||||
type=str,
|
||||
help="Optional. The model used for the root agent.",
|
||||
)
|
||||
@click.option(
|
||||
"--api_key",
|
||||
type=str,
|
||||
help=(
|
||||
"Optional. The API Key needed to access the model, e.g. Google AI API"
|
||||
" Key."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--project",
|
||||
type=str,
|
||||
help="Optional. The Google Cloud Project for using VertexAI as backend.",
|
||||
)
|
||||
@click.option(
|
||||
"--region",
|
||||
type=str,
|
||||
help="Optional. The Google Cloud Region for using VertexAI as backend.",
|
||||
)
|
||||
@click.argument("app_name", type=str, required=True)
|
||||
def cli_create_cmd(
|
||||
app_name: str,
|
||||
model: Optional[str],
|
||||
api_key: Optional[str],
|
||||
project: Optional[str],
|
||||
region: Optional[str],
|
||||
):
|
||||
"""Creates a new app in the current folder with prepopulated agent template.
|
||||
|
||||
APP_NAME: required, the folder of the agent source code.
|
||||
|
||||
Example:
|
||||
|
||||
adk create path/to/my_app
|
||||
"""
|
||||
cli_create.run_cmd(
|
||||
app_name,
|
||||
model=model,
|
||||
google_api_key=api_key,
|
||||
google_cloud_project=project,
|
||||
google_cloud_region=region,
|
||||
)
|
||||
|
||||
|
||||
@main.command("run")
|
||||
@click.option(
|
||||
"--save_session",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to save the session to a json file on exit.",
|
||||
)
|
||||
@click.argument(
|
||||
"agent",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
def cli_run(agent: str, save_session: bool):
|
||||
"""Runs an interactive CLI for a certain agent.
|
||||
|
||||
AGENT: The path to the agent source code folder.
|
||||
|
||||
Example:
|
||||
|
||||
adk run path/to/my_agent
|
||||
"""
|
||||
logs.log_to_tmp_folder()
|
||||
|
||||
agent_parent_folder = os.path.dirname(agent)
|
||||
agent_folder_name = os.path.basename(agent)
|
||||
|
||||
asyncio.run(
|
||||
run_cli(
|
||||
agent_parent_dir=agent_parent_folder,
|
||||
agent_folder_name=agent_folder_name,
|
||||
save_session=save_session,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@main.command("eval")
|
||||
@click.argument(
|
||||
"agent_module_file_path",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
@click.argument("eval_set_file_path", nargs=-1)
|
||||
@click.option("--config_file_path", help="Optional. The path to config file.")
|
||||
@click.option(
|
||||
"--print_detailed_results",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to print detailed results on console or not.",
|
||||
)
|
||||
def cli_eval(
|
||||
agent_module_file_path: str,
|
||||
eval_set_file_path: tuple[str],
|
||||
config_file_path: str,
|
||||
print_detailed_results: bool,
|
||||
):
|
||||
"""Evaluates an agent given the eval sets.
|
||||
|
||||
AGENT_MODULE_FILE_PATH: The path to the __init__.py file that contains a
|
||||
module by the name "agent". "agent" module contains a root_agent.
|
||||
|
||||
EVAL_SET_FILE_PATH: You can specify one or more eval set file paths.
|
||||
|
||||
For each file, all evals will be run by default.
|
||||
|
||||
If you want to run only specific evals from a eval set, first create a comma
|
||||
separated list of eval names and then add that as a suffix to the eval set
|
||||
file name, demarcated by a `:`.
|
||||
|
||||
For example,
|
||||
|
||||
sample_eval_set_file.json:eval_1,eval_2,eval_3
|
||||
|
||||
This will only run eval_1, eval_2 and eval_3 from sample_eval_set_file.json.
|
||||
|
||||
CONFIG_FILE_PATH: The path to config file.
|
||||
|
||||
PRINT_DETAILED_RESULTS: Prints detailed results on the console.
|
||||
"""
|
||||
envs.load_dotenv_for_agent(agent_module_file_path, ".")
|
||||
|
||||
try:
|
||||
from .cli_eval import EvalMetric
|
||||
from .cli_eval import EvalResult
|
||||
from .cli_eval import EvalStatus
|
||||
from .cli_eval import get_evaluation_criteria_or_default
|
||||
from .cli_eval import get_root_agent
|
||||
from .cli_eval import parse_and_get_evals_to_run
|
||||
from .cli_eval import run_evals
|
||||
from .cli_eval import try_get_reset_func
|
||||
except ModuleNotFoundError:
|
||||
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
|
||||
|
||||
evaluation_criteria = get_evaluation_criteria_or_default(config_file_path)
|
||||
eval_metrics = []
|
||||
for metric_name, threshold in evaluation_criteria.items():
|
||||
eval_metrics.append(
|
||||
EvalMetric(metric_name=metric_name, threshold=threshold)
|
||||
)
|
||||
|
||||
print(f"Using evaluation creiteria: {evaluation_criteria}")
|
||||
|
||||
root_agent = get_root_agent(agent_module_file_path)
|
||||
reset_func = try_get_reset_func(agent_module_file_path)
|
||||
|
||||
eval_set_to_evals = parse_and_get_evals_to_run(eval_set_file_path)
|
||||
|
||||
try:
|
||||
eval_results = list(
|
||||
run_evals(
|
||||
eval_set_to_evals,
|
||||
root_agent,
|
||||
reset_func,
|
||||
eval_metrics,
|
||||
print_detailed_results=print_detailed_results,
|
||||
)
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
raise click.ClickException(MISSING_EVAL_DEPENDENCIES_MESSAGE)
|
||||
|
||||
print("*********************************************************************")
|
||||
eval_run_summary = {}
|
||||
|
||||
for eval_result in eval_results:
|
||||
eval_result: EvalResult
|
||||
|
||||
if eval_result.eval_set_file not in eval_run_summary:
|
||||
eval_run_summary[eval_result.eval_set_file] = [0, 0]
|
||||
|
||||
if eval_result.final_eval_status == EvalStatus.PASSED:
|
||||
eval_run_summary[eval_result.eval_set_file][0] += 1
|
||||
else:
|
||||
eval_run_summary[eval_result.eval_set_file][1] += 1
|
||||
print("Eval Run Summary")
|
||||
for eval_set_file, pass_fail_count in eval_run_summary.items():
|
||||
print(
|
||||
f"{eval_set_file}:\n Tests passed: {pass_fail_count[0]}\n Tests"
|
||||
f" failed: {pass_fail_count[1]}"
|
||||
)
|
||||
|
||||
|
||||
@main.command("web")
|
||||
@click.option(
|
||||
"--session_db_url",
|
||||
help=(
|
||||
"""Optional. The database URL to store the session.
|
||||
|
||||
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
||||
|
||||
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
||||
|
||||
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Optional. The port of the server",
|
||||
default=8000,
|
||||
)
|
||||
@click.option(
|
||||
"--allow_origins",
|
||||
help="Optional. Any additional origins to allow for CORS.",
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
"--log_level",
|
||||
type=click.Choice(
|
||||
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
|
||||
),
|
||||
default="INFO",
|
||||
help="Optional. Set the logging level",
|
||||
)
|
||||
@click.option(
|
||||
"--log_to_tmp",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help=(
|
||||
"Optional. Whether to log to system temp folder instead of console."
|
||||
" This is useful for local debugging."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--trace_to_cloud",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to enable cloud trace for telemetry.",
|
||||
)
|
||||
@click.argument(
|
||||
"agents_dir",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
default=os.getcwd,
|
||||
)
|
||||
def cli_web(
|
||||
agents_dir: str,
|
||||
log_to_tmp: bool,
|
||||
session_db_url: str = "",
|
||||
log_level: str = "INFO",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
port: int = 8000,
|
||||
trace_to_cloud: bool = False,
|
||||
):
|
||||
"""Starts a FastAPI server with Web UI for agents.
|
||||
|
||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||
agent, containing at least `__init__.py` and `agent.py` files.
|
||||
|
||||
Example:
|
||||
|
||||
adk web --session_db_url=[db_url] --port=[port] path/to/agents_dir
|
||||
"""
|
||||
if log_to_tmp:
|
||||
logs.log_to_tmp_folder()
|
||||
else:
|
||||
logs.log_to_stderr()
|
||||
|
||||
logging.getLogger().setLevel(log_level)
|
||||
|
||||
@asynccontextmanager
|
||||
async def _lifespan(app: FastAPI):
|
||||
click.secho(
|
||||
f"""
|
||||
+-----------------------------------------------------------------------------+
|
||||
| ADK Web Server started |
|
||||
| |
|
||||
| For local testing, access at http://localhost:{port}.{" "*(29 - len(str(port)))}|
|
||||
+-----------------------------------------------------------------------------+
|
||||
""",
|
||||
fg="green",
|
||||
)
|
||||
yield # Startup is done, now app is running
|
||||
click.secho(
|
||||
"""
|
||||
+-----------------------------------------------------------------------------+
|
||||
| ADK Web Server shutting down... |
|
||||
+-----------------------------------------------------------------------------+
|
||||
""",
|
||||
fg="green",
|
||||
)
|
||||
|
||||
app = get_fast_api_app(
|
||||
agent_dir=agents_dir,
|
||||
session_db_url=session_db_url,
|
||||
allow_origins=allow_origins,
|
||||
web=True,
|
||||
trace_to_cloud=trace_to_cloud,
|
||||
lifespan=_lifespan,
|
||||
)
|
||||
config = uvicorn.Config(
|
||||
app,
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
)
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
server.run()
|
||||
|
||||
|
||||
@main.command("api_server")
|
||||
@click.option(
|
||||
"--session_db_url",
|
||||
help=(
|
||||
"""Optional. The database URL to store the session.
|
||||
|
||||
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
||||
|
||||
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
||||
|
||||
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
help="Optional. The port of the server",
|
||||
default=8000,
|
||||
)
|
||||
@click.option(
|
||||
"--allow_origins",
|
||||
help="Optional. Any additional origins to allow for CORS.",
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
"--log_level",
|
||||
type=click.Choice(
|
||||
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
|
||||
),
|
||||
default="INFO",
|
||||
help="Optional. Set the logging level",
|
||||
)
|
||||
@click.option(
|
||||
"--log_to_tmp",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help=(
|
||||
"Optional. Whether to log to system temp folder instead of console."
|
||||
" This is useful for local debugging."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--trace_to_cloud",
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to enable cloud trace for telemetry.",
|
||||
)
|
||||
# The directory of agents, where each sub-directory is a single agent.
|
||||
# By default, it is the current working directory
|
||||
@click.argument(
|
||||
"agents_dir",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
default=os.getcwd(),
|
||||
)
|
||||
def cli_api_server(
|
||||
agents_dir: str,
|
||||
log_to_tmp: bool,
|
||||
session_db_url: str = "",
|
||||
log_level: str = "INFO",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
port: int = 8000,
|
||||
trace_to_cloud: bool = False,
|
||||
):
|
||||
"""Starts a FastAPI server for agents.
|
||||
|
||||
AGENTS_DIR: The directory of agents, where each sub-directory is a single
|
||||
agent, containing at least `__init__.py` and `agent.py` files.
|
||||
|
||||
Example:
|
||||
|
||||
adk api_server --session_db_url=[db_url] --port=[port] path/to/agents_dir
|
||||
"""
|
||||
if log_to_tmp:
|
||||
logs.log_to_tmp_folder()
|
||||
else:
|
||||
logs.log_to_stderr()
|
||||
|
||||
logging.getLogger().setLevel(log_level)
|
||||
|
||||
config = uvicorn.Config(
|
||||
get_fast_api_app(
|
||||
agent_dir=agents_dir,
|
||||
session_db_url=session_db_url,
|
||||
allow_origins=allow_origins,
|
||||
web=False,
|
||||
trace_to_cloud=trace_to_cloud,
|
||||
),
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
server.run()
|
||||
|
||||
|
||||
@deploy.command("cloud_run")
|
||||
@click.option(
|
||||
"--project",
|
||||
type=str,
|
||||
help=(
|
||||
"Required. Google Cloud project to deploy the agent. When absent,"
|
||||
" default project from gcloud config is used."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--region",
|
||||
type=str,
|
||||
help=(
|
||||
"Required. Google Cloud region to deploy the agent. When absent,"
|
||||
" gcloud run deploy will prompt later."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--service_name",
|
||||
type=str,
|
||||
default="adk-default-service-name",
|
||||
help=(
|
||||
"Optional. The service name to use in Cloud Run (default:"
|
||||
" 'adk-default-service-name')."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--app_name",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Optional. App name of the ADK API server (default: the folder name"
|
||||
" of the AGENT source code)."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Optional. The port of the ADK API server (default: 8000).",
|
||||
)
|
||||
@click.option(
|
||||
"--trace_to_cloud",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help="Optional. Whether to enable Cloud Trace for cloud run.",
|
||||
)
|
||||
@click.option(
|
||||
"--with_ui",
|
||||
type=bool,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
default=False,
|
||||
help=(
|
||||
"Optional. Deploy ADK Web UI if set. (default: deploy ADK API server"
|
||||
" only)"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--temp_folder",
|
||||
type=str,
|
||||
default=os.path.join(
|
||||
tempfile.gettempdir(),
|
||||
"cloud_run_deploy_src",
|
||||
datetime.now().strftime("%Y%m%d_%H%M%S"),
|
||||
),
|
||||
help=(
|
||||
"Optional. Temp folder for the generated Cloud Run source files"
|
||||
" (default: a timestamped folder in the system temp directory)."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--verbosity",
|
||||
type=click.Choice(
|
||||
["debug", "info", "warning", "error", "critical"], case_sensitive=False
|
||||
),
|
||||
default="WARNING",
|
||||
help="Optional. Override the default verbosity level.",
|
||||
)
|
||||
@click.option(
|
||||
"--session_db_url",
|
||||
help=(
|
||||
"""Optional. The database URL to store the session.
|
||||
|
||||
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
||||
|
||||
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
||||
|
||||
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
|
||||
),
|
||||
)
|
||||
@click.argument(
|
||||
"agent",
|
||||
type=click.Path(
|
||||
exists=True, dir_okay=True, file_okay=False, resolve_path=True
|
||||
),
|
||||
)
|
||||
def cli_deploy_cloud_run(
|
||||
agent: str,
|
||||
project: Optional[str],
|
||||
region: Optional[str],
|
||||
service_name: str,
|
||||
app_name: str,
|
||||
temp_folder: str,
|
||||
port: int,
|
||||
trace_to_cloud: bool,
|
||||
with_ui: bool,
|
||||
verbosity: str,
|
||||
session_db_url: str,
|
||||
):
|
||||
"""Deploys an agent to Cloud Run.
|
||||
|
||||
AGENT: The path to the agent source code folder.
|
||||
|
||||
Example:
|
||||
|
||||
adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent
|
||||
"""
|
||||
try:
|
||||
cli_deploy.to_cloud_run(
|
||||
agent_folder=agent,
|
||||
project=project,
|
||||
region=region,
|
||||
service_name=service_name,
|
||||
app_name=app_name,
|
||||
temp_folder=temp_folder,
|
||||
port=port,
|
||||
trace_to_cloud=trace_to_cloud,
|
||||
with_ui=with_ui,
|
||||
verbosity=verbosity,
|
||||
session_db_url=session_db_url,
|
||||
)
|
||||
except Exception as e:
|
||||
click.secho(f"Deploy failed: {e}", fg="red", err=True)
|
||||
822
.venv/lib/python3.10/site-packages/google/adk/cli/fast_api.py
Normal file
822
.venv/lib/python3.10/site-packages/google/adk/cli/fast_api.py
Normal file
@@ -0,0 +1,822 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from click import Tuple
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.websockets import WebSocket
|
||||
from fastapi.websockets import WebSocketDisconnect
|
||||
from google.genai import types
|
||||
import graphviz
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
|
||||
from opentelemetry.sdk.trace import export
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from starlette.types import Lifespan
|
||||
|
||||
from ..agents import RunConfig
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.llm_agent import Agent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts import InMemoryArtifactService
|
||||
from ..events.event import Event
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..runners import Runner
|
||||
from ..sessions.database_session_service import DatabaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
||||
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
||||
from .cli_eval import EvalMetric
|
||||
from .cli_eval import EvalMetricResult
|
||||
from .cli_eval import EvalStatus
|
||||
from .utils import create_empty_state
|
||||
from .utils import envs
|
||||
from .utils import evals
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
||||
|
||||
|
||||
class ApiServerSpanExporter(export.SpanExporter):
|
||||
|
||||
def __init__(self, trace_dict):
|
||||
self.trace_dict = trace_dict
|
||||
|
||||
def export(
|
||||
self, spans: typing.Sequence[ReadableSpan]
|
||||
) -> export.SpanExportResult:
|
||||
for span in spans:
|
||||
if (
|
||||
span.name == "call_llm"
|
||||
or span.name == "send_data"
|
||||
or span.name.startswith("tool_response")
|
||||
):
|
||||
attributes = dict(span.attributes)
|
||||
attributes["trace_id"] = span.get_span_context().trace_id
|
||||
attributes["span_id"] = span.get_span_context().span_id
|
||||
if attributes.get("gcp.vertex.agent.event_id", None):
|
||||
self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
|
||||
return export.SpanExportResult.SUCCESS
|
||||
|
||||
def force_flush(self, timeout_millis: int = 30000) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class AgentRunRequest(BaseModel):
|
||||
app_name: str
|
||||
user_id: str
|
||||
session_id: str
|
||||
new_message: types.Content
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
class AddSessionToEvalSetRequest(BaseModel):
|
||||
eval_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
class RunEvalRequest(BaseModel):
|
||||
eval_ids: list[str] # if empty, then all evals in the eval set are run.
|
||||
eval_metrics: list[EvalMetric]
|
||||
|
||||
|
||||
class RunEvalResult(BaseModel):
|
||||
eval_set_id: str
|
||||
eval_id: str
|
||||
final_eval_status: EvalStatus
|
||||
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
|
||||
session_id: str
|
||||
|
||||
|
||||
def get_fast_api_app(
|
||||
*,
|
||||
agent_dir: str,
|
||||
session_db_url: str = "",
|
||||
allow_origins: Optional[list[str]] = None,
|
||||
web: bool,
|
||||
trace_to_cloud: bool = False,
|
||||
lifespan: Optional[Lifespan[FastAPI]] = None,
|
||||
) -> FastAPI:
|
||||
# InMemory tracing dict.
|
||||
trace_dict: dict[str, Any] = {}
|
||||
|
||||
# Set up tracing in the FastAPI server.
|
||||
provider = TracerProvider()
|
||||
provider.add_span_processor(
|
||||
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
||||
)
|
||||
if trace_to_cloud:
|
||||
envs.load_dotenv_for_agent("", agent_dir)
|
||||
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
||||
processor = export.BatchSpanProcessor(
|
||||
CloudTraceSpanExporter(project_id=project_id)
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
else:
|
||||
logging.warning(
|
||||
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
|
||||
" not be enabled."
|
||||
)
|
||||
|
||||
trace.set_tracer_provider(provider)
|
||||
|
||||
exit_stacks = []
|
||||
|
||||
@asynccontextmanager
|
||||
async def internal_lifespan(app: FastAPI):
|
||||
if lifespan:
|
||||
async with lifespan(app) as lifespan_context:
|
||||
yield
|
||||
|
||||
if exit_stacks:
|
||||
for stack in exit_stacks:
|
||||
await stack.aclose()
|
||||
else:
|
||||
yield
|
||||
|
||||
# Run the FastAPI server.
|
||||
app = FastAPI(lifespan=internal_lifespan)
|
||||
|
||||
if allow_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if agent_dir not in sys.path:
|
||||
sys.path.append(agent_dir)
|
||||
|
||||
runner_dict = {}
|
||||
root_agent_dict = {}
|
||||
|
||||
# Build the Artifact service
|
||||
artifact_service = InMemoryArtifactService()
|
||||
memory_service = InMemoryMemoryService()
|
||||
|
||||
# Build the Session service
|
||||
agent_engine_id = ""
|
||||
if session_db_url:
|
||||
if session_db_url.startswith("agentengine://"):
|
||||
# Create vertex session service
|
||||
agent_engine_id = session_db_url.split("://")[1]
|
||||
if not agent_engine_id:
|
||||
raise click.ClickException("Agent engine id can not be empty.")
|
||||
envs.load_dotenv_for_agent("", agent_dir)
|
||||
session_service = VertexAiSessionService(
|
||||
os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
)
|
||||
else:
|
||||
session_service = DatabaseSessionService(db_url=session_db_url)
|
||||
else:
|
||||
session_service = InMemorySessionService()
|
||||
|
||||
@app.get("/list-apps")
|
||||
def list_apps() -> list[str]:
|
||||
base_path = Path.cwd() / agent_dir
|
||||
if not base_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Path not found")
|
||||
if not base_path.is_dir():
|
||||
raise HTTPException(status_code=400, detail="Not a directory")
|
||||
agent_names = [
|
||||
x
|
||||
for x in os.listdir(base_path)
|
||||
if os.path.isdir(os.path.join(base_path, x))
|
||||
and not x.startswith(".")
|
||||
and x != "__pycache__"
|
||||
]
|
||||
agent_names.sort()
|
||||
return agent_names
|
||||
|
||||
@app.get("/debug/trace/{event_id}")
|
||||
def get_trace_dict(event_id: str) -> Any:
|
||||
event_dict = trace_dict.get(event_id, None)
|
||||
if event_dict is None:
|
||||
raise HTTPException(status_code=404, detail="Trace not found")
|
||||
return event_dict
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def get_session(app_name: str, user_id: str, session_id: str) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return [
|
||||
session
|
||||
for session in session_service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
).sessions
|
||||
# Remove sessions that were generated as a part of Eval.
|
||||
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
|
||||
]
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_session_with_id(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
if (
|
||||
session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
is not None
|
||||
):
|
||||
logger.warning("Session already exists: %s", session_id)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Session already exists: {session_id}"
|
||||
)
|
||||
|
||||
logger.info("New session created: %s", session_id)
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state, session_id=session_id
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_session(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
|
||||
logger.info("New session created")
|
||||
return session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state
|
||||
)
|
||||
|
||||
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
|
||||
return os.path.join(
|
||||
agent_dir,
|
||||
app_name,
|
||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def create_eval_set(
|
||||
app_name: str,
|
||||
eval_set_id: str,
|
||||
):
|
||||
"""Creates an eval set, given the id."""
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
if not bool(re.fullmatch(pattern, eval_set_id)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Invalid eval set id. Eval set id should have the `{pattern}`"
|
||||
" format"
|
||||
),
|
||||
)
|
||||
# Define the file path
|
||||
new_eval_set_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
|
||||
logger.info("Creating eval set file `%s`", new_eval_set_path)
|
||||
|
||||
if not os.path.exists(new_eval_set_path):
|
||||
# Write the JSON string to the file
|
||||
logger.info("Eval set file doesn't exist, we will create a new one.")
|
||||
with open(new_eval_set_path, "w") as f:
|
||||
empty_content = json.dumps([], indent=2)
|
||||
f.write(empty_content)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/eval_sets",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_eval_sets(app_name: str) -> list[str]:
|
||||
"""Lists all eval sets for the given app."""
|
||||
eval_set_file_path = os.path.join(agent_dir, app_name)
|
||||
eval_sets = []
|
||||
for file in os.listdir(eval_set_file_path):
|
||||
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
||||
eval_sets.append(
|
||||
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
|
||||
)
|
||||
|
||||
return sorted(eval_sets)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def add_session_to_eval_set(
|
||||
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
||||
):
|
||||
pattern = r"^[a-zA-Z0-9_]+$"
|
||||
if not bool(re.fullmatch(pattern, req.eval_id)):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
|
||||
)
|
||||
|
||||
# Get the session
|
||||
session = session_service.get_session(
|
||||
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
assert session, "Session not found."
|
||||
# Load the eval set file data
|
||||
eval_set_file_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
with open(eval_set_file_path, "r") as file:
|
||||
eval_set_data = json.load(file) # Load JSON into a list
|
||||
|
||||
if [x for x in eval_set_data if x["name"] == req.eval_id]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
|
||||
" eval set."
|
||||
),
|
||||
)
|
||||
|
||||
# Convert the session data to evaluation format
|
||||
test_data = evals.convert_session_to_eval_format(session)
|
||||
|
||||
# Populate the session with initial session state.
|
||||
initial_session_state = create_empty_state(
|
||||
await _get_root_agent_async(app_name)
|
||||
)
|
||||
|
||||
eval_set_data.append({
|
||||
"name": req.eval_id,
|
||||
"data": test_data,
|
||||
"initial_session": {
|
||||
"state": initial_session_state,
|
||||
"app_name": app_name,
|
||||
"user_id": req.user_id,
|
||||
},
|
||||
})
|
||||
# Serialize the test data to JSON and write to the eval set file.
|
||||
with open(eval_set_file_path, "w") as f:
|
||||
f.write(json.dumps(eval_set_data, indent=2))
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_evals_in_eval_set(
|
||||
app_name: str,
|
||||
eval_set_id: str,
|
||||
) -> list[str]:
|
||||
"""Lists all evals in an eval set."""
|
||||
# Load the eval set file data
|
||||
eval_set_file_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
with open(eval_set_file_path, "r") as file:
|
||||
eval_set_data = json.load(file) # Load JSON into a list
|
||||
|
||||
return sorted([x["name"] for x in eval_set_data])
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def run_eval(
|
||||
app_name: str, eval_set_id: str, req: RunEvalRequest
|
||||
) -> list[RunEvalResult]:
|
||||
from .cli_eval import run_evals
|
||||
|
||||
"""Runs an eval given the details in the eval request."""
|
||||
# Create a mapping from eval set file to all the evals that needed to be
|
||||
# run.
|
||||
eval_set_file_path = _get_eval_set_file_path(
|
||||
app_name, agent_dir, eval_set_id
|
||||
)
|
||||
eval_set_to_evals = {eval_set_file_path: req.eval_ids}
|
||||
|
||||
if not req.eval_ids:
|
||||
logger.info(
|
||||
"Eval ids to run list is empty. We will all evals in the eval set."
|
||||
)
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
eval_results = list(
|
||||
run_evals(
|
||||
eval_set_to_evals,
|
||||
root_agent,
|
||||
getattr(root_agent, "reset_data", None),
|
||||
req.eval_metrics,
|
||||
session_service=session_service,
|
||||
artifact_service=artifact_service,
|
||||
)
|
||||
)
|
||||
|
||||
run_eval_results = []
|
||||
for eval_result in eval_results:
|
||||
run_eval_results.append(
|
||||
RunEvalResult(
|
||||
app_name=app_name,
|
||||
eval_set_id=eval_set_id,
|
||||
eval_id=eval_result.eval_id,
|
||||
final_eval_status=eval_result.final_eval_status,
|
||||
eval_metric_results=eval_result.eval_metric_results,
|
||||
session_id=eval_result.session_id,
|
||||
)
|
||||
)
|
||||
return run_eval_results
|
||||
|
||||
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
|
||||
def delete_session(app_name: str, user_id: str, session_id: str):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
session_service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def load_artifact(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
artifact_name: str,
|
||||
version: Optional[int] = Query(None),
|
||||
) -> Optional[types.Part]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact = artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
version=version,
|
||||
)
|
||||
if not artifact:
|
||||
raise HTTPException(status_code=404, detail="Artifact not found")
|
||||
return artifact
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def load_artifact_version(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
artifact_name: str,
|
||||
version_id: int,
|
||||
) -> Optional[types.Part]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact = artifact_service.load_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
version=version_id,
|
||||
)
|
||||
if not artifact:
|
||||
raise HTTPException(status_code=404, detail="Artifact not found")
|
||||
return artifact
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_artifact_names(
|
||||
app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return artifact_service.list_artifact_keys(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
def list_artifact_versions(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
) -> list[int]:
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
return artifact_service.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
)
|
||||
|
||||
@app.delete(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
||||
)
|
||||
def delete_artifact(
|
||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||
):
|
||||
app_name = agent_engine_id if agent_engine_id else app_name
|
||||
artifact_service.delete_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=artifact_name,
|
||||
)
|
||||
|
||||
@app.post("/run", response_model_exclude_none=True)
|
||||
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
runner = await _get_runner_async(req.app_name)
|
||||
events = [
|
||||
event
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
)
|
||||
]
|
||||
logger.info("Generated %s events in agent run: %s", len(events), events)
|
||||
return events
|
||||
|
||||
@app.post("/run_sse")
|
||||
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else req.app_name
|
||||
# SSE endpoint
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Convert the events to properly formatted SSE
|
||||
async def event_generator():
|
||||
try:
|
||||
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
runner = await _get_runner_async(req.app_name)
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
run_config=RunConfig(streaming_mode=stream_mode),
|
||||
):
|
||||
# Format as SSE data
|
||||
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
logger.info("Generated event in agent run streaming: %s", sse_event)
|
||||
yield f"data: {sse_event}\n\n"
|
||||
except Exception as e:
|
||||
logger.exception("Error in event_generator: %s", e)
|
||||
# You might want to yield an error event here
|
||||
yield f'data: {{"error": "{str(e)}"}}\n\n'
|
||||
|
||||
# Returns a streaming response with the proper media type for SSE
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def get_event_graph(
|
||||
app_name: str, user_id: str, session_id: str, event_id: str
|
||||
):
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
)
|
||||
session_events = session.events if session else []
|
||||
event = next((x for x in session_events if x.id == event_id), None)
|
||||
if not event:
|
||||
return {}
|
||||
|
||||
from . import agent_graph
|
||||
|
||||
function_calls = event.get_function_calls()
|
||||
function_responses = event.get_function_responses()
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
dot_graph = None
|
||||
if function_calls:
|
||||
function_call_highlights = []
|
||||
for function_call in function_calls:
|
||||
from_name = event.author
|
||||
to_name = function_call.name
|
||||
function_call_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
root_agent, function_call_highlights
|
||||
)
|
||||
elif function_responses:
|
||||
function_responses_highlights = []
|
||||
for function_response in function_responses:
|
||||
from_name = function_response.name
|
||||
to_name = event.author
|
||||
function_responses_highlights.append((from_name, to_name))
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
root_agent, function_responses_highlights
|
||||
)
|
||||
else:
|
||||
from_name = event.author
|
||||
to_name = ""
|
||||
dot_graph = agent_graph.get_agent_graph(
|
||||
root_agent, [(from_name, to_name)]
|
||||
)
|
||||
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
|
||||
return {"dot_src": dot_graph.source}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@app.websocket("/run_live")
|
||||
async def agent_live_run(
|
||||
websocket: WebSocket,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
modalities: List[Literal["TEXT", "AUDIO"]] = Query(
|
||||
default=["TEXT", "AUDIO"]
|
||||
), # Only allows "TEXT" or "AUDIO"
|
||||
) -> None:
|
||||
await websocket.accept()
|
||||
|
||||
# Connect to managed session if agent_engine_id is set.
|
||||
app_id = agent_engine_id if agent_engine_id else app_name
|
||||
session = session_service.get_session(
|
||||
app_name=app_id, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
# Accept first so that the client is aware of connection establishment,
|
||||
# then close with a specific code.
|
||||
await websocket.close(code=1002, reason="Session not found")
|
||||
return
|
||||
|
||||
live_request_queue = LiveRequestQueue()
|
||||
|
||||
async def forward_events():
|
||||
runner = await _get_runner_async(app_name)
|
||||
async for event in runner.run_live(
|
||||
session=session, live_request_queue=live_request_queue
|
||||
):
|
||||
await websocket.send_text(
|
||||
event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
)
|
||||
|
||||
async def process_messages():
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
# Validate and send the received message to the live queue.
|
||||
live_request_queue.send(LiveRequest.model_validate_json(data))
|
||||
except ValidationError as ve:
|
||||
logger.error("Validation error in process_messages: %s", ve)
|
||||
|
||||
# Run both tasks concurrently and cancel all if one fails.
|
||||
tasks = [
|
||||
asyncio.create_task(forward_events()),
|
||||
asyncio.create_task(process_messages()),
|
||||
]
|
||||
done, pending = await asyncio.wait(
|
||||
tasks, return_when=asyncio.FIRST_EXCEPTION
|
||||
)
|
||||
try:
|
||||
# This will re-raise any exception from the completed tasks.
|
||||
for task in done:
|
||||
task.result()
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected during process_messages.")
|
||||
except Exception as e:
|
||||
logger.exception("Error during live websocket communication: %s", e)
|
||||
traceback.print_exc()
|
||||
WEBSOCKET_INTERNAL_ERROR_CODE = 1011
|
||||
WEBSOCKET_MAX_BYTES_FOR_REASON = 123
|
||||
await websocket.close(
|
||||
code=WEBSOCKET_INTERNAL_ERROR_CODE,
|
||||
reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
|
||||
)
|
||||
finally:
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
async def _get_root_agent_async(app_name: str) -> Agent:
|
||||
"""Returns the root agent for the given app."""
|
||||
if app_name in root_agent_dict:
|
||||
return root_agent_dict[app_name]
|
||||
agent_module = importlib.import_module(app_name)
|
||||
if getattr(agent_module.agent, "root_agent"):
|
||||
root_agent = agent_module.agent.root_agent
|
||||
else:
|
||||
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
||||
|
||||
# Handle an awaitable root agent and await for the actual agent.
|
||||
if inspect.isawaitable(root_agent):
|
||||
try:
|
||||
agent, exit_stack = await root_agent
|
||||
exit_stacks.append(exit_stack)
|
||||
root_agent = agent
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"error getting root agent, {e}") from e
|
||||
|
||||
root_agent_dict[app_name] = root_agent
|
||||
return root_agent
|
||||
|
||||
async def _get_runner_async(app_name: str) -> Runner:
|
||||
"""Returns the runner for the given app."""
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
||||
if app_name in runner_dict:
|
||||
return runner_dict[app_name]
|
||||
root_agent = await _get_root_agent_async(app_name)
|
||||
runner = Runner(
|
||||
app_name=agent_engine_id if agent_engine_id else app_name,
|
||||
agent=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
memory_service=memory_service,
|
||||
)
|
||||
runner_dict[app_name] = runner
|
||||
return runner
|
||||
|
||||
if web:
|
||||
BASE_DIR = Path(__file__).parent.resolve()
|
||||
ANGULAR_DIST_PATH = BASE_DIR / "browser"
|
||||
|
||||
@app.get("/")
|
||||
async def redirect_to_dev_ui():
|
||||
return RedirectResponse("/dev-ui")
|
||||
|
||||
@app.get("/dev-ui")
|
||||
async def dev_ui():
|
||||
return FileResponse(BASE_DIR / "browser/index.html")
|
||||
|
||||
app.mount(
|
||||
"/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
|
||||
)
|
||||
return app
|
||||
@@ -0,0 +1,49 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
__all__ = [
|
||||
'create_empty_state',
|
||||
]
|
||||
|
||||
|
||||
def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]):
|
||||
for sub_agent in agent.sub_agents:
|
||||
_create_empty_state(sub_agent, all_state)
|
||||
|
||||
if (
|
||||
isinstance(agent, LlmAgent)
|
||||
and agent.instruction
|
||||
and isinstance(agent.instruction, str)
|
||||
):
|
||||
for key in re.findall(r'{([\w]+)}', agent.instruction):
|
||||
all_state[key] = ''
|
||||
|
||||
|
||||
def create_empty_state(
|
||||
agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Creates empty str for non-initialized states."""
|
||||
non_initialized_states = {}
|
||||
_create_empty_state(agent, non_initialized_states)
|
||||
for key in initialized_states or {}:
|
||||
if key in non_initialized_states:
|
||||
del non_initialized_states[key]
|
||||
return non_initialized_states
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,54 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def _walk_to_root_until_found(folder, filename) -> str:
|
||||
checkpath = os.path.join(folder, filename)
|
||||
if os.path.exists(checkpath) and os.path.isfile(checkpath):
|
||||
return checkpath
|
||||
|
||||
parent_folder = os.path.dirname(folder)
|
||||
if parent_folder == folder: # reached the root
|
||||
return ''
|
||||
|
||||
return _walk_to_root_until_found(parent_folder, filename)
|
||||
|
||||
|
||||
def load_dotenv_for_agent(
|
||||
agent_name: str, agent_parent_folder: str, filename: str = '.env'
|
||||
):
|
||||
"""Lods the .env file for the agent module."""
|
||||
|
||||
# Gets the folder of agent_module as starting_folder
|
||||
starting_folder = os.path.abspath(
|
||||
os.path.join(agent_parent_folder, agent_name)
|
||||
)
|
||||
dotenv_file_path = _walk_to_root_until_found(starting_folder, filename)
|
||||
if dotenv_file_path:
|
||||
load_dotenv(dotenv_file_path, override=True, verbose=True)
|
||||
logger.info(
|
||||
'Loaded %s file for %s at %s',
|
||||
filename,
|
||||
agent_name,
|
||||
dotenv_file_path,
|
||||
)
|
||||
else:
|
||||
logger.info('No %s file found for %s', filename, agent_name)
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ...sessions.session import Session
|
||||
|
||||
|
||||
def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
|
||||
"""Converts a session data into eval format.
|
||||
|
||||
Args:
|
||||
session: The session that should be converted.
|
||||
|
||||
Returns:
|
||||
list: A single evaluation dataset in the required format.
|
||||
"""
|
||||
eval_case = []
|
||||
events = session.events if session and session.events else []
|
||||
|
||||
for event in events:
|
||||
if event.author == 'user':
|
||||
if not event.content or not event.content.parts:
|
||||
continue
|
||||
|
||||
# Extract user query
|
||||
content = event.content
|
||||
parts = content.parts
|
||||
|
||||
query = parts[0].text or ''
|
||||
|
||||
# Find the corresponding tool usage or response for the query
|
||||
expected_tool_use = []
|
||||
intermediate_agent_responses = []
|
||||
|
||||
# Check subsequent events to extract tool uses or responses for this turn.
|
||||
for subsequent_event in events[events.index(event) + 1 :]:
|
||||
event_author = subsequent_event.author or 'agent'
|
||||
if event_author == 'user':
|
||||
# We found an event where the author was the user. This means that a
|
||||
# new turn has started. So close this turn here.
|
||||
break
|
||||
|
||||
if not subsequent_event.content or not subsequent_event.content.parts:
|
||||
continue
|
||||
|
||||
for subsequent_part in subsequent_event.content.parts:
|
||||
# Some events have both function call and reference
|
||||
|
||||
if subsequent_part.function_call:
|
||||
tool_name = subsequent_part.function_call.name or ''
|
||||
tool_input = subsequent_part.function_call.args or {}
|
||||
expected_tool_use.append({
|
||||
'tool_name': tool_name,
|
||||
'tool_input': tool_input,
|
||||
})
|
||||
elif subsequent_part.text:
|
||||
# Also keep track of all the natural language responses that
|
||||
# agent (or sub agents) generated.
|
||||
intermediate_agent_responses.append(
|
||||
{'author': event_author, 'text': subsequent_part.text}
|
||||
)
|
||||
|
||||
# If we are here then either we are done reading all the events or we
|
||||
# encountered an event that had content authored by the end-user.
|
||||
# This, basically means an end of turn.
|
||||
# We assume that the last natural language intermediate response is the
|
||||
# final response from the agent/model. We treat that as a reference.
|
||||
eval_case.append({
|
||||
'query': query,
|
||||
'expected_tool_use': expected_tool_use,
|
||||
'expected_intermediate_agent_responses': intermediate_agent_responses[
|
||||
:-1
|
||||
],
|
||||
'reference': (
|
||||
intermediate_agent_responses[-1]['text']
|
||||
if intermediate_agent_responses
|
||||
else ''
|
||||
),
|
||||
})
|
||||
|
||||
return eval_case
|
||||
@@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
LOGGING_FORMAT = (
|
||||
'%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
|
||||
)
|
||||
|
||||
|
||||
def log_to_stderr(level=logging.INFO):
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format=LOGGING_FORMAT,
|
||||
)
|
||||
|
||||
|
||||
def log_to_tmp_folder(
|
||||
level=logging.INFO,
|
||||
*,
|
||||
sub_folder: str = 'agents_log',
|
||||
log_file_prefix: str = 'agent',
|
||||
log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
|
||||
):
|
||||
"""Logs to system temp folder, instead of logging to stderr.
|
||||
|
||||
Args
|
||||
sub_folder: str = 'agents_log',
|
||||
log_file_prefix: str = 'agent',
|
||||
log_file_timestamp: str = time.strftime('%Y%m%d_%H%M%S'),
|
||||
|
||||
Returns
|
||||
the log file path.
|
||||
"""
|
||||
log_dir = os.path.join(tempfile.gettempdir(), sub_folder)
|
||||
log_filename = f'{log_file_prefix}.{log_file_timestamp}.log'
|
||||
log_filepath = os.path.join(log_dir, log_filename)
|
||||
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(log_filepath, mode='w')
|
||||
file_handler.setLevel(level)
|
||||
file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
root_logger.handlers = [] # Clear handles to disable logging to stderr
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
print(f'Log setup complete: {log_filepath}')
|
||||
|
||||
latest_log_link = os.path.join(log_dir, f'{log_file_prefix}.latest.log')
|
||||
if os.path.islink(latest_log_link):
|
||||
os.unlink(latest_log_link)
|
||||
os.symlink(log_filepath, latest_log_link)
|
||||
|
||||
print(f'To access latest log: tail -F {latest_log_link}')
|
||||
return log_filepath
|
||||
@@ -0,0 +1,49 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from .base_code_executor import BaseCodeExecutor
|
||||
from .code_executor_context import CodeExecutorContext
|
||||
from .unsafe_local_code_executor import UnsafeLocalCodeExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'BaseCodeExecutor',
|
||||
'CodeExecutorContext',
|
||||
'UnsafeLocalCodeExecutor',
|
||||
]
|
||||
|
||||
try:
|
||||
from .vertex_ai_code_executor import VertexAiCodeExecutor
|
||||
|
||||
__all__.append('VertexAiCodeExecutor')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'The Vertex sdk is not installed. If you want to use the Vertex Code'
|
||||
' Interpreter with agents, please install it. If not, you can ignore this'
|
||||
' warning.'
|
||||
)
|
||||
|
||||
try:
|
||||
from .container_code_executor import ContainerCodeExecutor
|
||||
|
||||
__all__.append('ContainerCodeExecutor')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'The docker sdk is not installed. If you want to use the Container Code'
|
||||
' Executor with agents, please install it. If not, you can ignore this'
|
||||
' warning.'
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,97 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from .code_execution_utils import CodeExecutionInput
|
||||
from .code_execution_utils import CodeExecutionResult
|
||||
|
||||
|
||||
class BaseCodeExecutor(BaseModel):
|
||||
"""Abstract base class for all code executors.
|
||||
|
||||
The code executor allows the agent to execute code blocks from model responses
|
||||
and incorporate the execution results into the final response.
|
||||
|
||||
Attributes:
|
||||
optimize_data_file: If true, extract and process data files from the model
|
||||
request and attach them to the code executor. Supported data file
|
||||
MimeTypes are [text/csv]. Default to False.
|
||||
stateful: Whether the code executor is stateful. Default to False.
|
||||
error_retry_attempts: The number of attempts to retry on consecutive code
|
||||
execution errors. Default to 2.
|
||||
code_block_delimiters: The list of the enclosing delimiters to identify the
|
||||
code blocks.
|
||||
execution_result_delimiters: The delimiters to format the code execution
|
||||
result.
|
||||
"""
|
||||
|
||||
optimize_data_file: bool = False
|
||||
"""
|
||||
If true, extract and process data files from the model request
|
||||
and attach them to the code executor.
|
||||
Supported data file MimeTypes are [text/csv].
|
||||
|
||||
Default to False.
|
||||
"""
|
||||
|
||||
stateful: bool = False
|
||||
"""
|
||||
Whether the code executor is stateful. Default to False.
|
||||
"""
|
||||
|
||||
error_retry_attempts: int = 2
|
||||
"""
|
||||
The number of attempts to retry on consecutive code execution errors. Default to 2.
|
||||
"""
|
||||
|
||||
code_block_delimiters: List[tuple[str, str]] = [
|
||||
('```tool_code\n', '\n```'),
|
||||
('```python\n', '\n```'),
|
||||
]
|
||||
"""
|
||||
The list of the enclosing delimiters to identify the code blocks.
|
||||
For example, the delimiter ('```python\n', '\n```') can be
|
||||
used to identify code blocks with the following format:
|
||||
|
||||
```python
|
||||
print("hello")
|
||||
```
|
||||
"""
|
||||
|
||||
execution_result_delimiters: tuple[str, str] = ('```tool_output\n', '\n```')
|
||||
"""
|
||||
The delimiters to format the code execution result.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def execute_code(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
code_execution_input: CodeExecutionInput,
|
||||
) -> CodeExecutionResult:
|
||||
"""Executes code and return the code execution result.
|
||||
|
||||
Args:
|
||||
invocation_context: The invocation context of the code execution.
|
||||
code_execution_input: The code execution input.
|
||||
|
||||
Returns:
|
||||
The code execution result.
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,256 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions for code execution."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import copy
|
||||
import dataclasses
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class File:
|
||||
"""A structure that contains a file name and its content."""
|
||||
|
||||
name: str
|
||||
"""
|
||||
The name of the file with file extension (e.g., "file.csv").
|
||||
"""
|
||||
|
||||
content: str
|
||||
"""
|
||||
The base64-encoded bytes of the file content.
|
||||
"""
|
||||
|
||||
mime_type: str = 'text/plain'
|
||||
"""
|
||||
The mime type of the file (e.g., "image/png").
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodeExecutionInput:
|
||||
"""A structure that contains the input of code execution."""
|
||||
|
||||
code: str
|
||||
"""
|
||||
The code to execute.
|
||||
"""
|
||||
|
||||
input_files: list[File] = dataclasses.field(default_factory=list)
|
||||
"""
|
||||
The input files available to the code.
|
||||
"""
|
||||
|
||||
execution_id: Optional[str] = None
|
||||
"""
|
||||
The execution ID for the stateful code execution.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodeExecutionResult:
|
||||
"""A structure that contains the result of code execution."""
|
||||
|
||||
stdout: str = ''
|
||||
"""
|
||||
The standard output of the code execution.
|
||||
"""
|
||||
|
||||
stderr: str = ''
|
||||
"""
|
||||
The standard error of the code execution.
|
||||
"""
|
||||
|
||||
output_files: list[File] = dataclasses.field(default_factory=list)
|
||||
"""
|
||||
The output files from the code execution.
|
||||
"""
|
||||
|
||||
|
||||
class CodeExecutionUtils:
|
||||
"""Utility functions for code execution."""
|
||||
|
||||
@staticmethod
|
||||
def get_encoded_file_content(data: bytes) -> bytes:
|
||||
"""Gets the file content as a base64-encoded bytes.
|
||||
|
||||
Args:
|
||||
data: The file content bytes.
|
||||
|
||||
Returns:
|
||||
The file content as a base64-encoded bytes.
|
||||
"""
|
||||
|
||||
def _is_base64_encoded(data: bytes) -> bool:
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(data)) == data
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
return data if _is_base64_encoded(data) else base64.b64encode(data)
|
||||
|
||||
@staticmethod
|
||||
def extract_code_and_truncate_content(
|
||||
content: types.Content,
|
||||
code_block_delimiters: List[tuple[str, str]],
|
||||
) -> Optional[str]:
|
||||
"""Extracts the first code block from the content and truncate everything after it.
|
||||
|
||||
Args:
|
||||
content: The mutable content to extract the code from.
|
||||
code_block_delimiters: The list of the enclosing delimiters to identify
|
||||
the code blocks.
|
||||
|
||||
Returns:
|
||||
The first code block if found, otherwise None.
|
||||
"""
|
||||
if not content or not content.parts:
|
||||
return
|
||||
|
||||
# Extract the code from the executable code parts if there're no associated
|
||||
# code execution result parts.
|
||||
for idx, part in enumerate(content.parts):
|
||||
if part.executable_code and (
|
||||
idx == len(content.parts) - 1
|
||||
or not content.parts[idx + 1].code_execution_result
|
||||
):
|
||||
content.parts = content.parts[: idx + 1]
|
||||
return part.executable_code.code
|
||||
|
||||
# Extract the code from the text parts.
|
||||
text_parts = [p for p in content.parts if p.text]
|
||||
if not text_parts:
|
||||
return
|
||||
|
||||
first_text_part = copy.deepcopy(text_parts[0])
|
||||
response_text = '\n'.join([p.text for p in text_parts])
|
||||
|
||||
# Find the first code block.
|
||||
leading_delimiter_pattern = '|'.join(d[0] for d in code_block_delimiters)
|
||||
trailing_delimiter_pattern = '|'.join(d[1] for d in code_block_delimiters)
|
||||
pattern = re.compile(
|
||||
(
|
||||
rf'(?P<prefix>.*?)({leading_delimiter_pattern})(?P<code>.*?)({trailing_delimiter_pattern})(?P<suffix>.*?)$'
|
||||
).encode(),
|
||||
re.DOTALL,
|
||||
)
|
||||
pattern_match = pattern.search(response_text.encode())
|
||||
if pattern_match is None:
|
||||
return
|
||||
|
||||
code_str = pattern_match.group('code').decode()
|
||||
if not code_str:
|
||||
return
|
||||
|
||||
content.parts = []
|
||||
if pattern_match.group('prefix'):
|
||||
first_text_part.text = pattern_match.group('prefix').decode()
|
||||
content.parts.append(first_text_part)
|
||||
content.parts.append(
|
||||
CodeExecutionUtils.build_executable_code_part(code_str)
|
||||
)
|
||||
return pattern_match.group('code').decode()
|
||||
|
||||
@staticmethod
|
||||
def build_executable_code_part(code: str) -> types.Part:
|
||||
"""Builds an executable code part with code string.
|
||||
|
||||
Args:
|
||||
code: The code string.
|
||||
|
||||
Returns:
|
||||
The constructed executable code part.
|
||||
"""
|
||||
return types.Part.from_executable_code(
|
||||
code=code,
|
||||
language='PYTHON',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_code_execution_result_part(
|
||||
code_execution_result: CodeExecutionResult,
|
||||
) -> types.Part:
|
||||
"""Builds the code execution result part from the code execution result.
|
||||
|
||||
Args:
|
||||
code_execution_result: The code execution result.
|
||||
|
||||
Returns:
|
||||
The constructed code execution result part.
|
||||
"""
|
||||
if code_execution_result.stderr:
|
||||
return types.Part.from_code_execution_result(
|
||||
outcome='OUTCOME_FAILED',
|
||||
output=code_execution_result.stderr,
|
||||
)
|
||||
final_result = []
|
||||
if code_execution_result.stdout or not code_execution_result.output_files:
|
||||
final_result.append(
|
||||
'Code execution result:\n' + '%s\n' % code_execution_result.stdout
|
||||
)
|
||||
if code_execution_result.output_files:
|
||||
final_result.append(
|
||||
'Saved artifacts:\n'
|
||||
+ ','.join(
|
||||
['`%s`' % f.name for f in code_execution_result.output_files]
|
||||
)
|
||||
)
|
||||
return types.Part.from_code_execution_result(
|
||||
outcome='OUTCOME_OK',
|
||||
output='\n\n'.join(final_result),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_code_execution_parts(
|
||||
content: types.Content,
|
||||
code_block_delimiter: tuple[str, str],
|
||||
execution_result_delimiters: tuple[str, str],
|
||||
):
|
||||
"""Converts the code execution parts to text parts in a Content.
|
||||
|
||||
Args:
|
||||
content: The mutable content to convert the code execution parts to text
|
||||
parts.
|
||||
code_block_delimiter: The delimiter to format the code block.
|
||||
execution_result_delimiters: The delimiter to format the code execution
|
||||
result.
|
||||
"""
|
||||
if not content.parts:
|
||||
return
|
||||
|
||||
# Handle the conversion of trailing executable code parts.
|
||||
if content.parts[-1].executable_code:
|
||||
content.parts[-1] = types.Part(
|
||||
text=(
|
||||
code_block_delimiter[0]
|
||||
+ content.parts[-1].executable_code.code
|
||||
+ code_block_delimiter[1]
|
||||
)
|
||||
)
|
||||
# Handle the conversion of trailing code execution result parts.
|
||||
# Skip if the Content has multiple parts, which means the Content is
|
||||
# likely generated by the model.
|
||||
elif len(content.parts) == 1 and content.parts[-1].code_execution_result:
|
||||
content.parts[-1] = types.Part(
|
||||
text=execution_result_delimiters[0]
|
||||
+ content.parts[-1].code_execution_result.output
|
||||
+ execution_result_delimiters[1]
|
||||
)
|
||||
content.role = 'user'
|
||||
@@ -0,0 +1,202 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The persistent context used to configure the code executor."""
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import datetime
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from ..sessions.state import State
|
||||
from .code_execution_utils import File
|
||||
|
||||
_CONTEXT_KEY = '_code_execution_context'
|
||||
_SESSION_ID_KEY = 'execution_session_id'
|
||||
_PROCESSED_FILE_NAMES_KEY = 'processed_input_files'
|
||||
_INPUT_FILE_KEY = '_code_executor_input_files'
|
||||
_ERROR_COUNT_KEY = '_code_executor_error_counts'
|
||||
|
||||
_CODE_EXECUTION_RESULTS_KEY = '_code_execution_results'
|
||||
|
||||
|
||||
class CodeExecutorContext:
|
||||
"""The persistent context used to configure the code executor."""
|
||||
|
||||
_context: dict[str, Any]
|
||||
|
||||
def __init__(self, session_state: State):
|
||||
"""Initializes the code executor context.
|
||||
|
||||
Args:
|
||||
session_state: The session state to get the code executor context from.
|
||||
"""
|
||||
self._context = self._get_code_executor_context(session_state)
|
||||
self._session_state = session_state
|
||||
|
||||
def get_state_delta(self) -> dict[str, Any]:
|
||||
"""Gets the state delta to update in the persistent session state.
|
||||
|
||||
Returns:
|
||||
The state delta to update in the persistent session state.
|
||||
"""
|
||||
context_to_update = copy.deepcopy(self._context)
|
||||
return {_CONTEXT_KEY: context_to_update}
|
||||
|
||||
def get_execution_id(self) -> Optional[str]:
|
||||
"""Gets the session ID for the code executor.
|
||||
|
||||
Returns:
|
||||
The session ID for the code executor context.
|
||||
"""
|
||||
if _SESSION_ID_KEY not in self._context:
|
||||
return None
|
||||
return self._context[_SESSION_ID_KEY]
|
||||
|
||||
def set_execution_id(self, session_id: str):
|
||||
"""Sets the session ID for the code executor.
|
||||
|
||||
Args:
|
||||
session_id: The session ID for the code executor.
|
||||
"""
|
||||
self._context[_SESSION_ID_KEY] = session_id
|
||||
|
||||
def get_processed_file_names(self) -> list[str]:
|
||||
"""Gets the processed file names from the session state.
|
||||
|
||||
Returns:
|
||||
A list of processed file names in the code executor context.
|
||||
"""
|
||||
if _PROCESSED_FILE_NAMES_KEY not in self._context:
|
||||
return []
|
||||
return self._context[_PROCESSED_FILE_NAMES_KEY]
|
||||
|
||||
def add_processed_file_names(self, file_names: [str]):
|
||||
"""Adds the processed file name to the session state.
|
||||
|
||||
Args:
|
||||
file_names: The processed file names to add to the session state.
|
||||
"""
|
||||
if _PROCESSED_FILE_NAMES_KEY not in self._context:
|
||||
self._context[_PROCESSED_FILE_NAMES_KEY] = []
|
||||
self._context[_PROCESSED_FILE_NAMES_KEY].extend(file_names)
|
||||
|
||||
def get_input_files(self) -> list[File]:
|
||||
"""Gets the code executor input file names from the session state.
|
||||
|
||||
Returns:
|
||||
A list of input files in the code executor context.
|
||||
"""
|
||||
if _INPUT_FILE_KEY not in self._session_state:
|
||||
return []
|
||||
return [File(**file) for file in self._session_state[_INPUT_FILE_KEY]]
|
||||
|
||||
def add_input_files(
|
||||
self,
|
||||
input_files: list[File],
|
||||
):
|
||||
"""Adds the input files to the code executor context.
|
||||
|
||||
Args:
|
||||
input_files: The input files to add to the code executor context.
|
||||
"""
|
||||
if _INPUT_FILE_KEY not in self._session_state:
|
||||
self._session_state[_INPUT_FILE_KEY] = []
|
||||
for input_file in input_files:
|
||||
self._session_state[_INPUT_FILE_KEY].append(
|
||||
dataclasses.asdict(input_file)
|
||||
)
|
||||
|
||||
def clear_input_files(self):
|
||||
"""Removes the input files and processed file names to the code executor context."""
|
||||
if _INPUT_FILE_KEY in self._session_state:
|
||||
self._session_state[_INPUT_FILE_KEY] = []
|
||||
if _PROCESSED_FILE_NAMES_KEY in self._context:
|
||||
self._context[_PROCESSED_FILE_NAMES_KEY] = []
|
||||
|
||||
def get_error_count(self, invocation_id: str) -> int:
|
||||
"""Gets the error count from the session state.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to get the error count for.
|
||||
|
||||
Returns:
|
||||
The error count for the given invocation ID.
|
||||
"""
|
||||
if _ERROR_COUNT_KEY not in self._session_state:
|
||||
return 0
|
||||
return self._session_state[_ERROR_COUNT_KEY].get(invocation_id, 0)
|
||||
|
||||
def increment_error_count(self, invocation_id: str):
|
||||
"""Increments the error count from the session state.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to increment the error count for.
|
||||
"""
|
||||
if _ERROR_COUNT_KEY not in self._session_state:
|
||||
self._session_state[_ERROR_COUNT_KEY] = {}
|
||||
self._session_state[_ERROR_COUNT_KEY][invocation_id] = (
|
||||
self.get_error_count(invocation_id) + 1
|
||||
)
|
||||
|
||||
def reset_error_count(self, invocation_id: str):
|
||||
"""Resets the error count from the session state.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to reset the error count for.
|
||||
"""
|
||||
if _ERROR_COUNT_KEY not in self._session_state:
|
||||
return
|
||||
if invocation_id in self._session_state[_ERROR_COUNT_KEY]:
|
||||
del self._session_state[_ERROR_COUNT_KEY][invocation_id]
|
||||
|
||||
def update_code_execution_result(
|
||||
self,
|
||||
invocation_id: str,
|
||||
code: str,
|
||||
result_stdout: str,
|
||||
result_stderr: str,
|
||||
):
|
||||
"""Updates the code execution result.
|
||||
|
||||
Args:
|
||||
invocation_id: The invocation ID to update the code execution result for.
|
||||
code: The code to execute.
|
||||
result_stdout: The standard output of the code execution.
|
||||
result_stderr: The standard error of the code execution.
|
||||
"""
|
||||
if _CODE_EXECUTION_RESULTS_KEY not in self._session_state:
|
||||
self._session_state[_CODE_EXECUTION_RESULTS_KEY] = {}
|
||||
if invocation_id not in self._session_state[_CODE_EXECUTION_RESULTS_KEY]:
|
||||
self._session_state[_CODE_EXECUTION_RESULTS_KEY][invocation_id] = []
|
||||
self._session_state[_CODE_EXECUTION_RESULTS_KEY][invocation_id].append({
|
||||
'code': code,
|
||||
'result_stdout': result_stdout,
|
||||
'result_stderr': result_stderr,
|
||||
'timestamp': int(datetime.datetime.now().timestamp()),
|
||||
})
|
||||
|
||||
def _get_code_executor_context(self, session_state: State) -> dict[str, Any]:
|
||||
"""Gets the code executor context from the session state.
|
||||
|
||||
Args:
|
||||
session_state: The session state to get the code executor context from.
|
||||
|
||||
Returns:
|
||||
A dict of code executor context.
|
||||
"""
|
||||
if _CONTEXT_KEY not in session_state:
|
||||
session_state[_CONTEXT_KEY] = {}
|
||||
return session_state[_CONTEXT_KEY]
|
||||
@@ -0,0 +1,196 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import atexit
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import docker
|
||||
from docker.client import DockerClient
|
||||
from docker.models.containers import Container
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from .base_code_executor import BaseCodeExecutor
|
||||
from .code_execution_utils import CodeExecutionInput
|
||||
from .code_execution_utils import CodeExecutionResult
|
||||
|
||||
|
||||
DEFAULT_IMAGE_TAG = 'adk-code-executor:latest'
|
||||
|
||||
|
||||
class ContainerCodeExecutor(BaseCodeExecutor):
|
||||
"""A code executor that uses a custom container to execute code.
|
||||
|
||||
Attributes:
|
||||
base_url: Optional. The base url of the user hosted Docker client.
|
||||
image: The tag of the predefined image or custom image to run on the
|
||||
container. Either docker_path or image must be set.
|
||||
docker_path: The path to the directory containing the Dockerfile. If set,
|
||||
build the image from the dockerfile path instead of using the predefined
|
||||
image. Either docker_path or image must be set.
|
||||
"""
|
||||
|
||||
base_url: Optional[str] = None
|
||||
"""
|
||||
Optional. The base url of the user hosted Docker client.
|
||||
"""
|
||||
|
||||
image: str = None
|
||||
"""
|
||||
The tag of the predefined image or custom image to run on the container.
|
||||
Either docker_path or image must be set.
|
||||
"""
|
||||
|
||||
docker_path: str = None
|
||||
"""
|
||||
The path to the directory containing the Dockerfile.
|
||||
If set, build the image from the dockerfile path instead of using the
|
||||
predefined image. Either docker_path or image must be set.
|
||||
"""
|
||||
|
||||
# Overrides the BaseCodeExecutor attribute: this executor cannot be stateful.
|
||||
stateful: bool = Field(default=False, frozen=True, exclude=True)
|
||||
|
||||
# Overrides the BaseCodeExecutor attribute: this executor cannot
|
||||
# optimize_data_file.
|
||||
optimize_data_file: bool = Field(default=False, frozen=True, exclude=True)
|
||||
|
||||
_client: DockerClient = None
|
||||
_container: Container = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
docker_path: Optional[str] = None,
|
||||
**data,
|
||||
):
|
||||
"""Initializes the ContainerCodeExecutor.
|
||||
|
||||
Args:
|
||||
base_url: Optional. The base url of the user hosted Docker client.
|
||||
image: The tag of the predefined image or custom image to run on the
|
||||
container. Either docker_path or image must be set.
|
||||
docker_path: The path to the directory containing the Dockerfile. If set,
|
||||
build the image from the dockerfile path instead of using the predefined
|
||||
image. Either docker_path or image must be set.
|
||||
**data: The data to initialize the ContainerCodeExecutor.
|
||||
"""
|
||||
if not image and not docker_path:
|
||||
raise ValueError(
|
||||
'Either image or docker_path must be set for ContainerCodeExecutor.'
|
||||
)
|
||||
if 'stateful' in data and data['stateful']:
|
||||
raise ValueError('Cannot set `stateful=True` in ContainerCodeExecutor.')
|
||||
if 'optimize_data_file' in data and data['optimize_data_file']:
|
||||
raise ValueError(
|
||||
'Cannot set `optimize_data_file=True` in ContainerCodeExecutor.'
|
||||
)
|
||||
|
||||
super().__init__(**data)
|
||||
self.base_url = base_url
|
||||
self.image = image if image else DEFAULT_IMAGE_TAG
|
||||
self.docker_path = os.path.abspath(docker_path) if docker_path else None
|
||||
|
||||
self._client = (
|
||||
docker.from_env()
|
||||
if not self.base_url
|
||||
else docker.DockerClient(base_url=self.base_url)
|
||||
)
|
||||
# Initialize the container.
|
||||
self.__init_container()
|
||||
|
||||
# Close the container when the on exit.
|
||||
atexit.register(self.__cleanup_container)
|
||||
|
||||
@override
|
||||
def execute_code(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
code_execution_input: CodeExecutionInput,
|
||||
) -> CodeExecutionResult:
|
||||
output = ''
|
||||
error = ''
|
||||
exec_result = self._container.exec_run(
|
||||
['python3', '-c', code_execution_input.code],
|
||||
demux=True,
|
||||
)
|
||||
|
||||
if exec_result.output and exec_result.output[0]:
|
||||
output = exec_result.output[0].decode('utf-8')
|
||||
if (
|
||||
exec_result.output
|
||||
and len(exec_result.output) > 1
|
||||
and exec_result.output[1]
|
||||
):
|
||||
error = exec_result.output[1].decode('utf-8')
|
||||
|
||||
# Collect the final result.
|
||||
return CodeExecutionResult(
|
||||
stdout=output,
|
||||
stderr=error,
|
||||
output_files=[],
|
||||
)
|
||||
|
||||
def _build_docker_image(self):
|
||||
"""Builds the Docker image."""
|
||||
if not self.docker_path:
|
||||
raise ValueError('Docker path is not set.')
|
||||
if not os.path.exists(self.docker_path):
|
||||
raise FileNotFoundError(f'Invalid Docker path: {self.docker_path}')
|
||||
|
||||
print('Building Docker image...')
|
||||
self._client.images.build(
|
||||
path=self.docker_path,
|
||||
tag=self.image,
|
||||
rm=True,
|
||||
)
|
||||
print(f'Docker image: {self.image} built.')
|
||||
|
||||
def _verify_python_installation(self):
|
||||
"""Verifies the container has python3 installed."""
|
||||
exec_result = self._container.exec_run(['which', 'python3'])
|
||||
if exec_result.exit_code != 0:
|
||||
raise ValueError('python3 is not installed in the container.')
|
||||
|
||||
def __init_container(self):
|
||||
"""Initializes the container."""
|
||||
if not self._client:
|
||||
raise RuntimeError('Docker client is not initialized.')
|
||||
|
||||
if self.docker_path:
|
||||
self._build_docker_image()
|
||||
|
||||
print('Starting container for ContainerCodeExecutor...')
|
||||
self._container = self._client.containers.run(
|
||||
image=self.image,
|
||||
detach=True,
|
||||
tty=True,
|
||||
)
|
||||
print(f'Container {self._container.id} started.')
|
||||
|
||||
# Verify the container is able to run python3.
|
||||
self._verify_python_installation()
|
||||
|
||||
def __cleanup_container(self):
|
||||
"""Closes the container on exit."""
|
||||
if not self._container:
|
||||
return
|
||||
|
||||
print('[Cleanup] Stopping the container...')
|
||||
self._container.stop()
|
||||
self._container.remove()
|
||||
print(f'Container {self._container.id} stopped and removed.')
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user