structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,14 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,20 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import _code_execution
from . import _nl_planning
from . import contents
from . import functions
from . import identity
from . import instructions


@@ -0,0 +1,52 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the processor interface used for BaseLlmFlow."""
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from typing import AsyncGenerator
from typing import TYPE_CHECKING
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
if TYPE_CHECKING:
from ...models.llm_request import LlmRequest
from ...models.llm_response import LlmResponse
class BaseLlmRequestProcessor(ABC):
"""Base class for LLM request processor."""
@abstractmethod
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
"""Runs the processor."""
raise NotImplementedError("Not implemented.")
yield # AsyncGenerator requires a yield in function body.
class BaseLlmResponseProcessor(ABC):
"""Base class for LLM response processor."""
@abstractmethod
async def run_async(
self, invocation_context: InvocationContext, llm_response: LlmResponse
) -> AsyncGenerator[Event, None]:
"""Processes the LLM response."""
raise NotImplementedError("Not implemented.")
yield # AsyncGenerator requires a yield in function body.
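A minimal sketch (a hypothetical processor, not part of this commit) of how a concrete subclass plugs into this interface: it mutates the LlmRequest in place and, like the no-op processors elsewhere in this commit, uses the return/yield idiom to stay a valid async generator. `append_instructions` is assumed here from its use by the other processors in this commit.

class _ConcisenessRequestProcessor(BaseLlmRequestProcessor):
    """Hypothetical processor that appends a style instruction to every request."""

    async def run_async(
        self, invocation_context: InvocationContext, llm_request: LlmRequest
    ) -> AsyncGenerator[Event, None]:
        # A processor may also `yield Event(...)` to surface intermediate events.
        llm_request.append_instructions(['Keep answers concise.'])
        return
        yield  # AsyncGenerator requires a yield in function body.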


@@ -0,0 +1,458 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles Code Execution related logic."""
from __future__ import annotations
import base64
import copy
import dataclasses
import os
import re
from typing import AsyncGenerator
from typing import Generator
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from ...agents.invocation_context import InvocationContext
from ...code_executors.base_code_executor import BaseCodeExecutor
from ...code_executors.code_execution_utils import CodeExecutionInput
from ...code_executors.code_execution_utils import CodeExecutionResult
from ...code_executors.code_execution_utils import CodeExecutionUtils
from ...code_executors.code_execution_utils import File
from ...code_executors.code_executor_context import CodeExecutorContext
from ...events.event import Event
from ...events.event_actions import EventActions
from ...models.llm_response import LlmResponse
from ._base_llm_processor import BaseLlmRequestProcessor
from ._base_llm_processor import BaseLlmResponseProcessor
if TYPE_CHECKING:
from ...models.llm_request import LlmRequest
@dataclasses.dataclass
class DataFileUtil:
"""A structure that contains a data file name and its content."""
extension: str
"""
The file extension (e.g., ".csv").
"""
loader_code_template: str
"""
The code template to load the data file.
"""
_DATA_FILE_UTIL_MAP = {
'text/csv': DataFileUtil(
extension='.csv',
loader_code_template="pd.read_csv('{filename}')",
),
}
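As an aside, supporting another tabular format only requires a new entry in this map keyed by MIME type; placeholder naming and loader-code generation downstream both key off it. A hypothetical TSV entry (not part of this commit) might look like:

_DATA_FILE_UTIL_MAP['text/tab-separated-values'] = DataFileUtil(
    extension='.tsv',
    loader_code_template="pd.read_csv('{filename}', sep='\\t')",
)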
_DATA_FILE_HELPER_LIB = '''
import pandas as pd
def explore_df(df: pd.DataFrame) -> None:
"""Prints some information about a pandas DataFrame."""
with pd.option_context(
'display.max_columns', None, 'display.expand_frame_repr', False
):
# Print the column names to never encounter KeyError when selecting one.
df_dtypes = df.dtypes
# Obtain information about data types and missing values.
df_nulls = (len(df) - df.isnull().sum()).apply(
lambda x: f'{x} / {df.shape[0]} non-null'
)
# Count the unique values in each column using `.unique()`.
df_unique_count = df.apply(lambda x: len(x.unique()))
# Explore unique values in columns using `.unique()`.
df_unique = df.apply(lambda x: crop(str(list(x.unique()))))
df_info = pd.concat(
(
df_dtypes.rename('Dtype'),
df_nulls.rename('Non-Null Count'),
df_unique_count.rename('Unique Values Count'),
df_unique.rename('Unique Values'),
),
axis=1,
)
df_info.index.name = 'Columns'
print(f"""Total rows: {df.shape[0]}
Total columns: {df.shape[1]}
{df_info}""")
'''
class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
"""Processes code execution requests."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
if not isinstance(invocation_context.agent, LlmAgent):
return
if not invocation_context.agent.code_executor:
return
for event in _run_pre_processor(invocation_context, llm_request):
yield event
# Convert the code execution parts to text parts.
if not isinstance(invocation_context.agent.code_executor, BaseCodeExecutor):
return
for content in llm_request.contents:
CodeExecutionUtils.convert_code_execution_parts(
content,
invocation_context.agent.code_executor.code_block_delimiters[0]
if invocation_context.agent.code_executor.code_block_delimiters
else ('', ''),
invocation_context.agent.code_executor.execution_result_delimiters,
)
request_processor = _CodeExecutionRequestProcessor()
class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
"""Processes code execution responses."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_response: LlmResponse
) -> AsyncGenerator[Event, None]:
# Skip if the response is partial (streaming).
if llm_response.partial:
return
for event in _run_post_processor(invocation_context, llm_response):
yield event
response_processor = _CodeExecutionResponseProcessor()
def _run_pre_processor(
invocation_context: InvocationContext,
llm_request: LlmRequest,
) -> Generator[Event, None, None]:
"""Pre-process the user message by adding the user message to the Colab notebook."""
from ...agents.llm_agent import LlmAgent
if not isinstance(invocation_context.agent, LlmAgent):
return
agent = invocation_context.agent
code_executor = agent.code_executor
if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
return
if not code_executor.optimize_data_file:
return
code_executor_context = CodeExecutorContext(invocation_context.session.state)
# Skip if the error count exceeds the max retry attempts.
if (
code_executor_context.get_error_count(invocation_context.invocation_id)
>= code_executor.error_retry_attempts
):
return
# [Step 1] Extract data files from the session_history and store them in
# memory. Meanwhile, mutate the inline data file to text part in session
# history from all turns.
all_input_files = _extract_and_replace_inline_files(
code_executor_context, llm_request
)
# [Step 2] Run explore_df code on the data files from the current turn. We
# only need to explore the new data files because the previous data files
# should already be explored and cached in the code execution runtime.
processed_file_names = set(code_executor_context.get_processed_file_names())
files_to_process = [
f for f in all_input_files if f.name not in processed_file_names
]
for file in files_to_process:
code_str = _get_data_file_preprocessing_code(file)
# Skip for unsupported file or executor types.
if not code_str:
return
# Emit the code to execute, and add it to the LLM request.
code_content = types.Content(
role='model',
parts=[
types.Part(text=f'Processing input file: `{file.name}`'),
CodeExecutionUtils.build_executable_code_part(code_str),
],
)
llm_request.contents.append(copy.deepcopy(code_content))
yield Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
branch=invocation_context.branch,
content=code_content,
)
code_execution_result = code_executor.execute_code(
invocation_context,
CodeExecutionInput(
code=code_str,
input_files=[file],
execution_id=_get_or_set_execution_id(
invocation_context, code_executor_context
),
),
)
# Update the processing results to code executor context.
code_executor_context.update_code_execution_result(
invocation_context.invocation_id,
code_str,
code_execution_result.stdout,
code_execution_result.stderr,
)
code_executor_context.add_processed_file_names([file.name])
# Emit the execution result, and add it to the LLM request.
execution_result_event = _post_process_code_execution_result(
invocation_context, code_executor_context, code_execution_result
)
yield execution_result_event
llm_request.contents.append(copy.deepcopy(execution_result_event.content))
def _run_post_processor(
invocation_context: InvocationContext,
llm_response,
) -> Generator[Event, None, None]:
"""Post-process the model response by extracting and executing the first code block."""
agent = invocation_context.agent
code_executor = agent.code_executor
if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
return
if not llm_response or not llm_response.content:
return
code_executor_context = CodeExecutorContext(invocation_context.session.state)
# Skip if the error count exceeds the max retry attempts.
if (
code_executor_context.get_error_count(invocation_context.invocation_id)
>= code_executor.error_retry_attempts
):
return
# [Step 1] Extract code from the model predict response and truncate the
# content to the part with the first code block.
response_content = llm_response.content
code_str = CodeExecutionUtils.extract_code_and_truncate_content(
response_content, code_executor.code_block_delimiters
)
# Terminal state: no code to execute.
if not code_str:
return
# [Step 2] Execute the code and emit two events: one for the code and one for the execution result.
yield Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
branch=invocation_context.branch,
content=response_content,
actions=EventActions(),
)
code_execution_result = code_executor.execute_code(
invocation_context,
CodeExecutionInput(
code=code_str,
input_files=code_executor_context.get_input_files(),
execution_id=_get_or_set_execution_id(
invocation_context, code_executor_context
),
),
)
code_executor_context.update_code_execution_result(
invocation_context.invocation_id,
code_str,
code_execution_result.stdout,
code_execution_result.stderr,
)
yield _post_process_code_execution_result(
invocation_context, code_executor_context, code_execution_result
)
# [Step 3] Skip processing the original model response
# to continue the code generation loop.
llm_response.content = None
def _extract_and_replace_inline_files(
code_executor_context: CodeExecutorContext,
llm_request: LlmRequest,
) -> list[File]:
"""Extracts and replaces inline files with file names in the LLM request."""
all_input_files = code_executor_context.get_input_files()
saved_file_names = set(f.name for f in all_input_files)
# [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
for i in range(len(llm_request.contents)):
content = llm_request.contents[i]
# Only process the user message.
if content.role != 'user' or not content.parts:
continue
for j in range(len(content.parts)):
part = content.parts[j]
# Skip if the inline data is not supported.
if (
not part.inline_data
or part.inline_data.mime_type not in _DATA_FILE_UTIL_MAP
):
continue
# Replace the inline data file with a file name placeholder.
mime_type = part.inline_data.mime_type
file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension
llm_request.contents[i].parts[j] = types.Part(
text='\nAvailable file: `%s`\n' % file_name
)
# Add the inline data as an input file to the code executor context.
file = File(
name=file_name,
content=CodeExecutionUtils.get_encoded_file_content(
part.inline_data.data
).decode(),
mime_type=mime_type,
)
if file_name not in saved_file_names:
code_executor_context.add_input_files([file])
all_input_files.append(file)
return all_input_files
def _get_or_set_execution_id(
invocation_context: InvocationContext,
code_executor_context: CodeExecutorContext,
) -> Optional[str]:
"""Returns the ID for stateful code execution or None if not stateful."""
if not invocation_context.agent.code_executor.stateful:
return None
execution_id = code_executor_context.get_execution_id()
if not execution_id:
execution_id = invocation_context.session.id
code_executor_context.set_execution_id(execution_id)
return execution_id
def _post_process_code_execution_result(
invocation_context: InvocationContext,
code_executor_context: CodeExecutorContext,
code_execution_result: CodeExecutionResult,
) -> Event:
"""Post-process the code execution result and emit an Event."""
if invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
result_content = types.Content(
role='model',
parts=[
CodeExecutionUtils.build_code_execution_result_part(
code_execution_result
),
],
)
event_actions = EventActions(
state_delta=code_executor_context.get_state_delta()
)
# Handle code execution error retry.
if code_execution_result.stderr:
code_executor_context.increment_error_count(
invocation_context.invocation_id
)
else:
code_executor_context.reset_error_count(invocation_context.invocation_id)
# Handle output files.
for output_file in code_execution_result.output_files:
version = invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name,
user_id=invocation_context.user_id,
session_id=invocation_context.session.id,
filename=output_file.name,
artifact=types.Part.from_bytes(
data=base64.b64decode(output_file.content),
mime_type=output_file.mime_type,
),
)
event_actions.artifact_delta[output_file.name] = version
return Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
content=result_content,
actions=event_actions,
)
def _get_data_file_preprocessing_code(file: File) -> Optional[str]:
"""Returns the code to explore the data file."""
def _get_normalized_file_name(file_name: str) -> str:
var_name, _ = os.path.splitext(file_name)
# Replace non-alphanumeric characters with underscores
var_name = re.sub(r'[^a-zA-Z0-9_]', '_', var_name)
# If the filename starts with a digit, prepend an underscore
if var_name[0].isdigit():
var_name = '_' + var_name
return var_name
if file.mime_type not in _DATA_FILE_UTIL_MAP:
return
var_name = _get_normalized_file_name(file.name)
loader_code = _DATA_FILE_UTIL_MAP[file.mime_type].loader_code_template.format(
filename=file.name
)
return f"""
{_DATA_FILE_HELPER_LIB}
# Load the dataframe.
{var_name} = {loader_code}
# Use `explore_df` to guide my analysis.
explore_df({var_name})
"""


@@ -0,0 +1,135 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles NL planning related logic."""
from __future__ import annotations
from typing import AsyncGenerator
from typing import Generator
from typing import Optional
from typing import TYPE_CHECKING
from typing_extensions import override
from ...agents.callback_context import CallbackContext
from ...agents.invocation_context import InvocationContext
from ...agents.readonly_context import ReadonlyContext
from ...events.event import Event
from ...planners.plan_re_act_planner import PlanReActPlanner
from ._base_llm_processor import BaseLlmRequestProcessor
from ._base_llm_processor import BaseLlmResponseProcessor
if TYPE_CHECKING:
from ...models.llm_request import LlmRequest
from ...models.llm_response import LlmResponse
from ...planners.base_planner import BasePlanner
from ...planners.built_in_planner import BuiltInPlanner
class _NlPlanningRequestProcessor(BaseLlmRequestProcessor):
"""Processor for NL planning."""
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...planners.built_in_planner import BuiltInPlanner
planner = _get_planner(invocation_context)
if not planner:
return
if isinstance(planner, BuiltInPlanner):
planner.apply_thinking_config(llm_request)
planning_instruction = planner.build_planning_instruction(
ReadonlyContext(invocation_context), llm_request
)
if planning_instruction:
llm_request.append_instructions([planning_instruction])
_remove_thought_from_request(llm_request)
# Maintain async generator behavior
if False: # Ensures it behaves as a generator
yield # This is a no-op but maintains generator structure
request_processor = _NlPlanningRequestProcessor()
class _NlPlanningResponse(BaseLlmResponseProcessor):
@override
async def run_async(
self, invocation_context: InvocationContext, llm_response: LlmResponse
) -> AsyncGenerator[Event, None]:
if (
not llm_response
or not llm_response.content
or not llm_response.content.parts
):
return
planner = _get_planner(invocation_context)
if not planner:
return
# Postprocess the LLM response.
callback_context = CallbackContext(invocation_context)
processed_parts = planner.process_planning_response(
callback_context, llm_response.content.parts
)
if processed_parts:
llm_response.content.parts = processed_parts
if callback_context.state.has_delta():
state_update_event = Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
actions=callback_context._event_actions,
)
yield state_update_event
response_processor = _NlPlanningResponse()
def _get_planner(
invocation_context: InvocationContext,
) -> Optional[BasePlanner]:
from ...agents.llm_agent import Agent
from ...planners.base_planner import BasePlanner
agent = invocation_context.agent
if not isinstance(agent, Agent):
return None
if not agent.planner:
return None
if isinstance(agent.planner, BasePlanner):
return agent.planner
return PlanReActPlanner()
def _remove_thought_from_request(llm_request: LlmRequest):
if not llm_request.contents:
return
for content in llm_request.contents:
if not content.parts:
continue
for part in content.parts:
part.thought = None
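For illustration, the thought-stripping step above relies on the `thought` field of genai parts; a tiny standalone sketch, assuming google.genai `types.Part` exposes that field as the code above does:

from google.genai import types

parts = [
    types.Part(text='internal plan', thought=True),
    types.Part(text='final answer'),
]
for part in parts:
    part.thought = None  # Same per-part operation as _remove_thought_from_request.

assert all(part.thought is None for part in parts)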


@@ -0,0 +1,132 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles agent transfer for LLM flow."""
from __future__ import annotations
import typing
from typing import AsyncGenerator
from typing_extensions import override
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ...tools.function_tool import FunctionTool
from ...tools.tool_context import ToolContext
from ...tools.transfer_to_agent_tool import transfer_to_agent
from ._base_llm_processor import BaseLlmRequestProcessor
if typing.TYPE_CHECKING:
from ...agents import BaseAgent
from ...agents import LlmAgent
class _AgentTransferLlmRequestProcessor(BaseLlmRequestProcessor):
"""Agent transfer request processor."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
if not isinstance(invocation_context.agent, LlmAgent):
return
transfer_targets = _get_transfer_targets(invocation_context.agent)
if not transfer_targets:
return
llm_request.append_instructions([
_build_target_agents_instructions(
invocation_context.agent, transfer_targets
)
])
transfer_to_agent_tool = FunctionTool(func=transfer_to_agent)
tool_context = ToolContext(invocation_context)
await transfer_to_agent_tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
)
return
yield # AsyncGenerator requires yield statement in function body.
request_processor = _AgentTransferLlmRequestProcessor()
def _build_target_agents_info(target_agent: BaseAgent) -> str:
return f"""
Agent name: {target_agent.name}
Agent description: {target_agent.description}
"""
line_break = '\n'
def _build_target_agents_instructions(
agent: LlmAgent, target_agents: list[BaseAgent]
) -> str:
si = f"""
You have a list of other agents to transfer to:
{line_break.join([
_build_target_agents_info(target_agent) for target_agent in target_agents
])}
If you are the best agent to answer the question according to your description,
you can answer it.
If another agent is better suited to answer the question according to its
description, call the `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
question to that agent. When transferring, do not generate any text other than
the function call.
"""
if agent.parent_agent:
si += f"""
Your parent agent is {agent.parent_agent.name}. If neither the other agents nor
you are best suited to answer the question according to the descriptions, transfer
to your parent agent. If you don't have a parent agent, try to answer by yourself.
"""
return si
_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__
def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]:
from ...agents.llm_agent import LlmAgent
result = []
result.extend(agent.sub_agents)
if not agent.parent_agent or not isinstance(agent.parent_agent, LlmAgent):
return result
if not agent.disallow_transfer_to_parent:
result.append(agent.parent_agent)
if not agent.disallow_transfer_to_peers:
result.extend([
peer_agent
for peer_agent in agent.parent_agent.sub_agents
if peer_agent.name != agent.name
])
return result
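A toy sketch (hypothetical stand-in classes, not the ADK agent types) of the target-collection rules implemented by _get_transfer_targets: sub-agents are always eligible, while the parent and peers are added only when the corresponding disallow flags are unset.

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ToyAgent:
    name: str
    sub_agents: list['ToyAgent'] = field(default_factory=list)
    parent_agent: Optional['ToyAgent'] = None
    disallow_transfer_to_parent: bool = False
    disallow_transfer_to_peers: bool = False

def collect_targets(agent: ToyAgent) -> list[str]:
    """Mirrors _get_transfer_targets, minus the LlmAgent isinstance checks."""
    targets = [a.name for a in agent.sub_agents]
    if not agent.parent_agent:
        return targets
    if not agent.disallow_transfer_to_parent:
        targets.append(agent.parent_agent.name)
    if not agent.disallow_transfer_to_peers:
        targets.extend(
            peer.name
            for peer in agent.parent_agent.sub_agents
            if peer.name != agent.name
        )
    return targets

root = ToyAgent('root')
billing = ToyAgent('billing', parent_agent=root)
support = ToyAgent('support', parent_agent=root)
root.sub_agents = [billing, support]

assert collect_targets(billing) == ['root', 'support']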


@@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from google.cloud import speech
from google.genai import types as genai_types
if TYPE_CHECKING:
from ...agents.invocation_context import InvocationContext
class AudioTranscriber:
"""Transcribes audio using Google Cloud Speech-to-Text."""
def __init__(self):
self.client = speech.SpeechClient()
def transcribe_file(
self, invocation_context: InvocationContext
) -> list[genai_types.Content]:
"""Transcribe audio, bundling consecutive segments from the same speaker.
The ordering of speakers will be preserved. Audio blobs from the same speaker
will be merged as much as possible to reduce the transcription latency.
Args:
invocation_context: The invocation context to access the transcription
cache.
Returns:
A list of Content objects containing the transcribed text.
"""
bundled_audio = []
current_speaker = None
current_audio_data = b''
contents = []
# Step 1: merge audio blobs
for transcription_entry in invocation_context.transcription_cache or []:
speaker, audio_data = (
transcription_entry.role,
transcription_entry.data,
)
if isinstance(audio_data, genai_types.Content):
if current_speaker is not None:
bundled_audio.append((current_speaker, current_audio_data))
current_speaker = None
current_audio_data = b''
bundled_audio.append((speaker, audio_data))
continue
if not audio_data.data:
continue
if speaker == current_speaker:
current_audio_data += audio_data.data
else:
if current_speaker is not None:
bundled_audio.append((current_speaker, current_audio_data))
current_speaker = speaker
current_audio_data = audio_data.data
# Append the last audio segment if any
if current_speaker is not None:
bundled_audio.append((current_speaker, current_audio_data))
# reset cache
invocation_context.transcription_cache = []
# Step 2: transcription
for speaker, data in bundled_audio:
if speaker == 'user':
audio = speech.RecognitionAudio(content=data)
config = speech.RecognitionConfig(
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
sample_rate_hertz=16000,
language_code='en-US',
)
response = self.client.recognize(config=config, audio=audio)
for result in response.results:
transcript = result.alternatives[0].transcript
parts = [genai_types.Part(text=transcript)]
role = speaker.lower()
content = genai_types.Content(role=role, parts=parts)
contents.append(content)
else:
# No need to transcribe model responses, which are already text.
contents.append(data)
return contents
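The speaker-bundling step above can be pictured with a toy sketch (plain tuples instead of TranscriptionEntry): consecutive chunks from the same speaker are concatenated so each speaker turn goes to Speech-to-Text as a single request.

def bundle(entries: list[tuple[str, bytes]]) -> list[tuple[str, bytes]]:
    bundled: list[tuple[str, bytes]] = []
    for speaker, chunk in entries:
        if bundled and bundled[-1][0] == speaker:
            bundled[-1] = (speaker, bundled[-1][1] + chunk)
        else:
            bundled.append((speaker, chunk))
    return bundled

assert bundle([('user', b'a'), ('user', b'b'), ('model', b'c')]) == [
    ('user', b'ab'),
    ('model', b'c'),
]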


@@ -0,0 +1,49 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of AutoFlow."""
from . import agent_transfer
from .single_flow import SingleFlow
class AutoFlow(SingleFlow):
"""AutoFlow is SingleFlow with agent transfer capability.
Agent transfer is allowed in the following directions:
1. from parent to sub-agent;
2. from sub-agent to parent;
3. from sub-agent to its peer agents.
Peer-agent transfer is only enabled when all of the conditions below are met:
- The parent agent also uses AutoFlow;
- The `disallow_transfer_to_peers` option of this agent is False (default).
Depending on the target agent's flow type, the transfer may be automatically
reversed. The conditions are as below:
- If the flow type of the transferee agent is also auto, the transferee agent will
remain as the active agent. The transferee agent will respond to the user's
next message directly.
- If the flow type of the transferee agent is not auto, the active agent will
be reversed back to the previous agent.
TODO: allow the user to configure the auto-reverse behavior.
"""
def __init__(self):
super().__init__()
self.request_processors += [agent_transfer.request_processor]


@@ -0,0 +1,559 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import ABC
import asyncio
import logging
from typing import AsyncGenerator
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from websockets.exceptions import ConnectionClosedOK
from ...agents.base_agent import BaseAgent
from ...agents.callback_context import CallbackContext
from ...agents.invocation_context import InvocationContext
from ...agents.live_request_queue import LiveRequestQueue
from ...agents.run_config import StreamingMode
from ...agents.transcription_entry import TranscriptionEntry
from ...events.event import Event
from ...models.base_llm_connection import BaseLlmConnection
from ...models.llm_request import LlmRequest
from ...models.llm_response import LlmResponse
from ...telemetry import trace_call_llm
from ...telemetry import trace_send_data
from ...telemetry import tracer
from ...tools.tool_context import ToolContext
from . import functions
if TYPE_CHECKING:
from ...agents.llm_agent import LlmAgent
from ...models.base_llm import BaseLlm
from ._base_llm_processor import BaseLlmRequestProcessor
from ._base_llm_processor import BaseLlmResponseProcessor
logger = logging.getLogger(__name__)
class BaseLlmFlow(ABC):
"""A basic flow that calls the LLM in a loop until a final response is generated.
This flow ends when it transfers to another agent.
"""
def __init__(self):
self.request_processors: list[BaseLlmRequestProcessor] = []
self.response_processors: list[BaseLlmResponseProcessor] = []
async def run_live(
self,
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Runs the flow using live api."""
llm_request = LlmRequest()
event_id = Event.new_id()
# Preprocess before calling the LLM.
async for event in self._preprocess_async(invocation_context, llm_request):
yield event
if invocation_context.end_invocation:
return
llm = self.__get_llm(invocation_context)
logger.debug(
'Establishing live connection for agent: %s with llm request: %s',
invocation_context.agent.name,
llm_request,
)
async with llm.connect(llm_request) as llm_connection:
if llm_request.contents:
# Sends the conversation history to the model.
with tracer.start_as_current_span('send_data'):
if invocation_context.transcription_cache:
from . import audio_transcriber
audio_transcriber = audio_transcriber.AudioTranscriber()
contents = audio_transcriber.transcribe_file(invocation_context)
logger.debug('Sending history to model: %s', contents)
await llm_connection.send_history(contents)
invocation_context.transcription_cache = None
trace_send_data(invocation_context, event_id, contents)
else:
await llm_connection.send_history(llm_request.contents)
trace_send_data(invocation_context, event_id, llm_request.contents)
send_task = asyncio.create_task(
self._send_to_model(llm_connection, invocation_context)
)
try:
async for event in self._receive_from_model(
llm_connection,
event_id,
invocation_context,
llm_request,
):
# Empty event means the queue is closed.
if not event:
break
logger.debug('Receive new event: %s', event)
yield event
# send back the function response
if event.get_function_responses():
logger.debug('Sending back last function response event: %s', event)
invocation_context.live_request_queue.send_content(event.content)
if (
event.content
and event.content.parts
and event.content.parts[0].function_response
and event.content.parts[0].function_response.name
== 'transfer_to_agent'
):
await asyncio.sleep(1)
# Cancel the tasks that belong to the closed connection.
send_task.cancel()
await llm_connection.close()
finally:
# Clean up
if not send_task.done():
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
async def _send_to_model(
self,
llm_connection: BaseLlmConnection,
invocation_context: InvocationContext,
):
"""Sends data to model."""
while True:
live_request_queue = invocation_context.live_request_queue
try:
# Streamlit's execution model doesn't preemptively yield to the event
# loop. Therefore, we must explicitly introduce timeouts to allow the
# event loop to process events.
# TODO: revert back(remove timeout) once we move off streamlit.
live_request = await asyncio.wait_for(
live_request_queue.get(), timeout=0.25
)
# duplicate the live_request to all the active streams
logger.debug(
'Sending live request %s to active streams: %s',
live_request,
invocation_context.active_streaming_tools,
)
if invocation_context.active_streaming_tools:
for active_streaming_tool in (
invocation_context.active_streaming_tools
).values():
if active_streaming_tool.stream:
active_streaming_tool.stream.send(live_request)
await asyncio.sleep(0)
except asyncio.TimeoutError:
continue
if live_request.close:
await llm_connection.close()
return
if live_request.blob:
# Cache audio data here for transcription
if not invocation_context.transcription_cache:
invocation_context.transcription_cache = []
invocation_context.transcription_cache.append(
TranscriptionEntry(role='user', data=live_request.blob)
)
await llm_connection.send_realtime(live_request.blob)
if live_request.content:
await llm_connection.send_content(live_request.content)
async def _receive_from_model(
self,
llm_connection: BaseLlmConnection,
event_id: str,
invocation_context: InvocationContext,
llm_request: LlmRequest,
) -> AsyncGenerator[Event, None]:
"""Receive data from model and process events using BaseLlmConnection."""
assert invocation_context.live_request_queue
try:
while True:
async for llm_response in llm_connection.receive():
model_response_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
)
async for event in self._postprocess_live(
invocation_context,
llm_request,
llm_response,
model_response_event,
):
if (
event.content
and event.content.parts
and event.content.parts[0].text
and not event.partial
):
if not invocation_context.transcription_cache:
invocation_context.transcription_cache = []
invocation_context.transcription_cache.append(
TranscriptionEntry(role='model', data=event.content)
)
yield event
# Give opportunity for other tasks to run.
await asyncio.sleep(0)
except ConnectionClosedOK:
pass
async def run_async(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Runs the flow."""
while True:
last_event = None
async for event in self._run_one_step_async(invocation_context):
last_event = event
yield event
if not last_event or last_event.is_final_response():
break
async def _run_one_step_async(
self,
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""One step means one LLM call."""
llm_request = LlmRequest()
# Preprocess before calling the LLM.
async for event in self._preprocess_async(invocation_context, llm_request):
yield event
if invocation_context.end_invocation:
return
# Calls the LLM.
model_response_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
)
async for llm_response in self._call_llm_async(
invocation_context, llm_request, model_response_event
):
# Postprocess after calling the LLM.
async for event in self._postprocess_async(
invocation_context, llm_request, llm_response, model_response_event
):
yield event
async def _preprocess_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
# Runs processors.
for processor in self.request_processors:
async for event in processor.run_async(invocation_context, llm_request):
yield event
# Run processors for tools.
for tool in agent.canonical_tools:
tool_context = ToolContext(invocation_context)
await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
)
async def _postprocess_async(
self,
invocation_context: InvocationContext,
llm_request: LlmRequest,
llm_response: LlmResponse,
model_response_event: Event,
) -> AsyncGenerator[Event, None]:
"""Postprocess after calling the LLM.
Args:
invocation_context: The invocation context.
llm_request: The original LLM request.
llm_response: The LLM response from the LLM call.
model_response_event: A mutable event for the LLM response.
Yields:
A generator of events.
"""
# Runs processors.
async for event in self._postprocess_run_processors_async(
invocation_context, llm_response
):
yield event
# Skip the model response event if there is no content and no error code.
# This is needed for the code executor to trigger another loop.
if (
not llm_response.content
and not llm_response.error_code
and not llm_response.interrupted
):
return
# Builds the event.
model_response_event = self._finalize_model_response_event(
llm_request, llm_response, model_response_event
)
yield model_response_event
# Handles function calls.
if model_response_event.get_function_calls():
async for event in self._postprocess_handle_function_calls_async(
invocation_context, model_response_event, llm_request
):
yield event
async def _postprocess_live(
self,
invocation_context: InvocationContext,
llm_request: LlmRequest,
llm_response: LlmResponse,
model_response_event: Event,
) -> AsyncGenerator[Event, None]:
"""Postprocess after calling the LLM asynchronously.
Args:
invocation_context: The invocation context.
llm_request: The original LLM request.
llm_response: The LLM response from the LLM call.
model_response_event: A mutable event for the LLM response.
Yields:
A generator of events.
"""
# Runs processors.
async for event in self._postprocess_run_processors_async(
invocation_context, llm_response
):
yield event
# Skip the model response event if there is no content and no error code.
# This is needed for the code executor to trigger another loop.
# But don't skip control events like turn_complete.
if (
not llm_response.content
and not llm_response.error_code
and not llm_response.interrupted
and not llm_response.turn_complete
):
return
# Builds the event.
model_response_event = self._finalize_model_response_event(
llm_request, llm_response, model_response_event
)
yield model_response_event
# Handles function calls.
if model_response_event.get_function_calls():
function_response_event = await functions.handle_function_calls_live(
invocation_context, model_response_event, llm_request.tools_dict
)
yield function_response_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
invocation_context, transfer_to_agent
)
async for item in agent_to_run.run_live(invocation_context):
yield item
async def _postprocess_run_processors_async(
self, invocation_context: InvocationContext, llm_response: LlmResponse
) -> AsyncGenerator[Event, None]:
for processor in self.response_processors:
async for event in processor.run_async(invocation_context, llm_response):
yield event
async def _postprocess_handle_function_calls_async(
self,
invocation_context: InvocationContext,
function_call_event: Event,
llm_request: LlmRequest,
) -> AsyncGenerator[Event, None]:
if function_response_event := await functions.handle_function_calls_async(
invocation_context, function_call_event, llm_request.tools_dict
):
auth_event = functions.generate_auth_event(
invocation_context, function_response_event
)
if auth_event:
yield auth_event
yield function_response_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
invocation_context, transfer_to_agent
)
async for event in agent_to_run.run_async(invocation_context):
yield event
def _get_agent_to_run(
self, invocation_context: InvocationContext, transfer_to_agent
) -> BaseAgent:
root_agent = invocation_context.agent.root_agent
agent_to_run = root_agent.find_agent(transfer_to_agent)
if not agent_to_run:
raise ValueError(
f'Agent {transfer_to_agent} not found in the agent tree.'
)
return agent_to_run
async def _call_llm_async(
self,
invocation_context: InvocationContext,
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
# Runs before_model_callback if it exists.
if response := self._handle_before_model_callback(
invocation_context, llm_request, model_response_event
):
yield response
return
# Calls the LLM.
llm = self.__get_llm(invocation_context)
with tracer.start_as_current_span('call_llm'):
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
async for llm_response in self.run_live(invocation_context):
# Runs after_model_callback if it exists.
if altered_llm_response := self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
# only yield partial response in SSE streaming mode
if (
invocation_context.run_config.streaming_mode == StreamingMode.SSE
or not llm_response.partial
):
yield llm_response
if llm_response.turn_complete:
invocation_context.live_request_queue.close()
else:
# Check whether we can make this LLM call. If the current call pushes the
# counter beyond the configured maximum, execution is stopped right here and
# an exception is thrown.
invocation_context.increment_llm_call_count()
async for llm_response in llm.generate_content_async(
llm_request,
stream=invocation_context.run_config.streaming_mode
== StreamingMode.SSE,
):
trace_call_llm(
invocation_context,
model_response_event.id,
llm_request,
llm_response,
)
# Runs after_model_callback if it exists.
if altered_llm_response := self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
yield llm_response
def _handle_before_model_callback(
self,
invocation_context: InvocationContext,
llm_request: LlmRequest,
model_response_event: Event,
) -> Optional[LlmResponse]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
if not agent.before_model_callback:
return
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
return agent.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
def _handle_after_model_callback(
self,
invocation_context: InvocationContext,
llm_response: LlmResponse,
model_response_event: Event,
) -> Optional[LlmResponse]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
if not agent.after_model_callback:
return
callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
return agent.after_model_callback(
callback_context=callback_context, llm_response=llm_response
)
def _finalize_model_response_event(
self,
llm_request: LlmRequest,
llm_response: LlmResponse,
model_response_event: Event,
) -> Event:
model_response_event = Event.model_validate({
**model_response_event.model_dump(exclude_none=True),
**llm_response.model_dump(exclude_none=True),
})
if model_response_event.content:
function_calls = model_response_event.get_function_calls()
if function_calls:
functions.populate_client_function_call_id(model_response_event)
model_response_event.long_running_tool_ids = (
functions.get_long_running_function_calls(
function_calls, llm_request.tools_dict
)
)
return model_response_event
def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
from ...agents.llm_agent import LlmAgent
return cast(LlmAgent, invocation_context.agent).canonical_model
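A minimal sketch of how a concrete flow composes the processors defined elsewhere in this commit, mirroring the pattern AutoFlow uses; the module-level request_processor/response_processor names are the ones shown in the files above.

from . import _code_execution
from . import _nl_planning
from . import agent_transfer
from . import basic
from . import contents

class CustomFlow(BaseLlmFlow):
    """Hypothetical flow: basic request building, history, planning, code execution, transfer."""

    def __init__(self):
        super().__init__()
        self.request_processors += [
            basic.request_processor,
            contents.request_processor,
            _nl_planning.request_processor,
            _code_execution.request_processor,
            agent_transfer.request_processor,
        ]
        self.response_processors += [
            _nl_planning.response_processor,
            _code_execution.response_processor,
        ]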


@@ -0,0 +1,72 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles basic information to build the LLM request."""
from __future__ import annotations
from typing import AsyncGenerator
from typing import Generator
from google.genai import types
from typing_extensions import override
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ._base_llm_processor import BaseLlmRequestProcessor
class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
llm_request.model = (
agent.canonical_model
if isinstance(agent.canonical_model, str)
else agent.canonical_model.model
)
llm_request.config = (
agent.generate_content_config.model_copy(deep=True)
if agent.generate_content_config
else types.GenerateContentConfig()
)
if agent.output_schema:
llm_request.set_output_schema(agent.output_schema)
llm_request.live_connect_config.response_modalities = (
invocation_context.run_config.response_modalities
)
llm_request.live_connect_config.speech_config = (
invocation_context.run_config.speech_config
)
llm_request.live_connect_config.output_audio_transcription = (
invocation_context.run_config.output_audio_transcription
)
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
return
yield # Generator requires yield statement in function body.
request_processor = _BasicLlmRequestProcessor()


@@ -0,0 +1,390 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
from typing import AsyncGenerator
from typing import Generator
from typing import Optional
from google.genai import types
from typing_extensions import override
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ._base_llm_processor import BaseLlmRequestProcessor
from .functions import remove_client_function_call_id
from .functions import REQUEST_EUC_FUNCTION_CALL_NAME
class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
"""Builds the contents for the LLM request."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
if agent.include_contents != 'none':
llm_request.contents = _get_contents(
invocation_context.branch,
invocation_context.session.events,
agent.name,
)
# Maintain async generator behavior
if False: # Ensures it behaves as a generator
yield # This is a no-op but maintains generator structure
request_processor = _ContentLlmRequestProcessor()
def _rearrange_events_for_async_function_responses_in_history(
events: list[Event],
) -> list[Event]:
"""Rearrange the async function_response events in the history."""
function_call_id_to_response_events_index: dict[str, int] = {}
for i, event in enumerate(events):
function_responses = event.get_function_responses()
if function_responses:
for function_response in function_responses:
function_call_id = function_response.id
function_call_id_to_response_events_index[function_call_id] = i
result_events: list[Event] = []
for event in events:
if event.get_function_responses():
# function_response should be handled together with function_call below.
continue
elif event.get_function_calls():
function_response_events_indices = set()
for function_call in event.get_function_calls():
function_call_id = function_call.id
if function_call_id in function_call_id_to_response_events_index:
function_response_events_indices.add(
function_call_id_to_response_events_index[function_call_id]
)
result_events.append(event)
if not function_response_events_indices:
continue
if len(function_response_events_indices) == 1:
result_events.append(
events[next(iter(function_response_events_indices))]
)
else: # Merge all async function_response as one response event
result_events.append(
_merge_function_response_events(
[events[i] for i in sorted(function_response_events_indices)]
)
)
continue
else:
result_events.append(event)
return result_events
def _rearrange_events_for_latest_function_response(
events: list[Event],
) -> list[Event]:
"""Rearrange the events for the latest function_response.
If the latest function_response is for an async function_call, all events
between the initial function_call and the latest function_response will be
removed.
Args:
events: A list of events.
Returns:
A list of events with the latest function_response rearranged.
"""
if not events:
return events
function_responses = events[-1].get_function_responses()
if not function_responses:
# No need to process, since the latest event is not a function_response.
return events
function_responses_ids = set()
for function_response in function_responses:
function_responses_ids.add(function_response.id)
function_calls = events[-2].get_function_calls()
if function_calls:
for function_call in function_calls:
# The latest function_response is already matched
if function_call.id in function_responses_ids:
return events
function_call_event_idx = -1
# Look for the corresponding function_call event in reverse order.
for idx in range(len(events) - 2, -1, -1):
event = events[idx]
function_calls = event.get_function_calls()
if function_calls:
for function_call in function_calls:
if function_call.id in function_responses_ids:
function_call_event_idx = idx
break
if function_call_event_idx != -1:
# In case the last response event only has part of the responses for the
# function calls in the function_call event.
for function_call in function_calls:
function_responses_ids.add(function_call.id)
break
if function_call_event_idx == -1:
raise ValueError(
'No function call event found for function responses ids:'
f' {function_responses_ids}'
)
# Collect all function_response events between the function_call event and
# the last function_response event.
function_response_events: list[Event] = []
for idx in range(function_call_event_idx + 1, len(events) - 1):
event = events[idx]
function_responses = event.get_function_responses()
if (
function_responses
and function_responses[0].id in function_responses_ids
):
function_response_events.append(event)
function_response_events.append(events[-1])
result_events = events[: function_call_event_idx + 1]
result_events.append(
_merge_function_response_events(function_response_events)
)
return result_events
def _get_contents(
current_branch: Optional[str], events: list[Event], agent_name: str = ''
) -> list[types.Content]:
"""Get the contents for the LLM request.
Args:
current_branch: The current branch of the agent.
events: A list of events.
agent_name: The name of the agent.
Returns:
A list of contents.
"""
filtered_events = []
# Parse the events, leaving the contents and the function calls and
# responses from the current agent.
for event in events:
if not event.content or not event.content.role:
# Skip events without content, or that were generated neither by the user nor
# by the model, e.g. events purely for mutating session state.
continue
if not _is_event_belongs_to_branch(current_branch, event):
# Skip events that don't belong to the current branch.
continue
if _is_auth_event(event):
# Skip auth events.
continue
filtered_events.append(
_convert_foreign_event(event)
if _is_other_agent_reply(agent_name, event)
else event
)
result_events = _rearrange_events_for_latest_function_response(
filtered_events
)
result_events = _rearrange_events_for_async_function_responses_in_history(
result_events
)
contents = []
for event in result_events:
content = copy.deepcopy(event.content)
remove_client_function_call_id(content)
contents.append(content)
return contents
def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
"""Whether the event is a reply from another agent."""
return bool(
current_agent_name
and event.author != current_agent_name
and event.author != 'user'
)
def _convert_foreign_event(event: Event) -> Event:
"""Converts an event authored by another agent as a user-content event.
This is to provide another agent's output as context to the current agent, so
that current agent can continue to respond, such as summarizing previous
agent's reply, etc.
Args:
event: The event to convert.
Returns:
The converted event.
"""
if not event.content or not event.content.parts:
return event
content = types.Content()
content.role = 'user'
content.parts = [types.Part(text='For context:')]
for part in event.content.parts:
if part.text:
content.parts.append(
types.Part(text=f'[{event.author}] said: {part.text}')
)
elif part.function_call:
content.parts.append(
types.Part(
text=(
f'[{event.author}] called tool `{part.function_call.name}`'
f' with parameters: {part.function_call.args}'
)
)
)
elif part.function_response:
# For function responses, create a new text part describing the tool result.
content.parts.append(
types.Part(
text=(
f'[{event.author}] `{part.function_response.name}` tool'
f' returned result: {part.function_response.response}'
)
)
)
# Fallback to the original part for non-text and non-functionCall parts.
else:
content.parts.append(part)
return Event(
timestamp=event.timestamp,
author='user',
content=content,
branch=event.branch,
)
def _merge_function_response_events(
function_response_events: list[Event],
) -> Event:
"""Merges a list of function_response events into one event.
The key goal is to ensure:
1. function_call and function_response are always of the same number.
2. The function_call and function_response are consecutively in the content.
Args:
function_response_events: A list of function_response events.
NOTE: function_response_events must fulfill these requirements: 1. The
list is in increasing order of timestamp; 2. the first event is the
initial function_response event; 3. all later events should contain at
least one function_response part that relates to the function_call
event. (Note: 3. may not hold when an async function returns an
intermediate response; there could also be intermediate model response
events without any function_response, and such events will be ignored.)
Caveat: This implementation doesn't support a parallel function_call
event that contains async function_calls of the same name.
Returns:
A merged event, that is
1. All later function_response will replace function_response part in
the initial function_response event.
2. All non-function_response parts will be appended to the part list of
the initial function_response event.
"""
if not function_response_events:
raise ValueError('At least one function_response event is required.')
merged_event = function_response_events[0].model_copy(deep=True)
parts_in_merged_event: list[types.Part] = merged_event.content.parts # type: ignore
if not parts_in_merged_event:
raise ValueError('There should be at least one function_response part.')
part_indices_in_merged_event: dict[str, int] = {}
for idx, part in enumerate(parts_in_merged_event):
if part.function_response:
function_call_id: str = part.function_response.id # type: ignore
part_indices_in_merged_event[function_call_id] = idx
for event in function_response_events[1:]:
if not event.content.parts:
raise ValueError('There should be at least one function_response part.')
for part in event.content.parts:
if part.function_response:
function_call_id: str = part.function_response.id # type: ignore
if function_call_id in part_indices_in_merged_event:
parts_in_merged_event[
part_indices_in_merged_event[function_call_id]
] = part
else:
parts_in_merged_event.append(part)
part_indices_in_merged_event[function_call_id] = (
len(parts_in_merged_event) - 1
)
else:
parts_in_merged_event.append(part)
return merged_event
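# Illustrative sketch (not part of the original module): merging an initial
# function_response event with a later event that carries the final result for
# the same function_call id. The tool name, id, and payloads below are made up.
#
#   initial = Event(
#       author='agent',
#       content=types.Content(
#           role='user',
#           parts=[types.Part.from_function_response(
#               name='search', response={'status': 'pending'})],
#       ),
#   )
#   initial.content.parts[0].function_response.id = 'adk-123'
#   final = initial.model_copy(deep=True)
#   final.content.parts[0].function_response.response = {'result': 'done'}
#   merged = _merge_function_response_events([initial, final])
#   # merged has a single function_response part for id 'adk-123' whose
#   # response is {'result': 'done'}.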
def _is_event_belongs_to_branch(
invocation_branch: Optional[str], event: Event
) -> bool:
"""Event belongs to a branch, when event.branch is prefix of the invocation branch."""
if not invocation_branch or not event.branch:
return True
return invocation_branch.startswith(event.branch)
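# Illustrative sketch (not part of the original module): the prefix check above
# means an event produced on branch 'root.planner' is visible to an invocation
# running on branch 'root.planner.coder', but not the other way around. The
# branch names are made up.
#
#   _is_event_belongs_to_branch('root.planner.coder',
#                               Event(author='a', branch='root.planner'))        # True
#   _is_event_belongs_to_branch('root.planner',
#                               Event(author='a', branch='root.planner.coder'))  # False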
def _is_auth_event(event: Event) -> bool:
if not event.content.parts:
return False
for part in event.content.parts:
if (
part.function_call
and part.function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
return True
if (
part.function_response
and part.function_response.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
return True
return False

View File

@@ -0,0 +1,486 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles function callings for LLM flow."""
from __future__ import annotations
import asyncio
import inspect
import logging
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Optional
import uuid
from google.genai import types
from ...agents.active_streaming_tool import ActiveStreamingTool
from ...agents.invocation_context import InvocationContext
from ...auth.auth_tool import AuthToolArguments
from ...events.event import Event
from ...events.event_actions import EventActions
from ...telemetry import trace_tool_call
from ...telemetry import trace_tool_response
from ...telemetry import tracer
from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
logger = logging.getLogger(__name__)
def generate_client_function_call_id() -> str:
return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'
def populate_client_function_call_id(model_response_event: Event) -> None:
if not model_response_event.get_function_calls():
return
for function_call in model_response_event.get_function_calls():
if not function_call.id:
function_call.id = generate_client_function_call_id()
def remove_client_function_call_id(content: types.Content) -> None:
if content and content.parts:
for part in content.parts:
if (
part.function_call
and part.function_call.id
and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
):
part.function_call.id = None
if (
part.function_response
and part.function_response.id
and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
):
part.function_response.id = None
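# Illustrative sketch (not part of the original module): lifecycle of the
# client-generated function_call ids handled by the two helpers above. The id
# value is made up.
#
#   # 1. The model returns function calls without ids; populate them:
#   populate_client_function_call_id(model_response_event)
#   #    -> each function_call.id becomes something like 'adk-<uuid4>'
#   # 2. When the history is assembled for the next model request, the
#   #    synthetic ids are stripped again so the model never sees them:
#   remove_client_function_call_id(content)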
def get_long_running_function_calls(
function_calls: list[types.FunctionCall],
tools_dict: dict[str, BaseTool],
) -> set[str]:
long_running_tool_ids = set()
for function_call in function_calls:
if (
function_call.name in tools_dict
and tools_dict[function_call.name].is_long_running
):
long_running_tool_ids.add(function_call.id)
return long_running_tool_ids
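# Illustrative sketch (not part of the original module): collecting the ids of
# calls to long-running tools so they can be tracked separately. The event and
# tool names are made up.
#
#   calls = model_response_event.get_function_calls()
#   ids = get_long_running_function_calls(calls, tools_dict)
#   # ids contains the call ids whose tool has is_long_running == True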
def generate_auth_event(
invocation_context: InvocationContext,
function_response_event: Event,
) -> Optional[Event]:
if not function_response_event.actions.requested_auth_configs:
return None
parts = []
long_running_tool_ids = set()
for (
function_call_id,
auth_config,
) in function_response_event.actions.requested_auth_configs.items():
request_euc_function_call = types.FunctionCall(
name=REQUEST_EUC_FUNCTION_CALL_NAME,
args=AuthToolArguments(
function_call_id=function_call_id,
auth_config=auth_config,
).model_dump(exclude_none=True),
)
request_euc_function_call.id = generate_client_function_call_id()
long_running_tool_ids.add(request_euc_function_call.id)
parts.append(types.Part(function_call=request_euc_function_call))
return Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
content=types.Content(
parts=parts, role=function_response_event.content.role
),
long_running_tool_ids=long_running_tool_ids,
)
async def handle_function_calls_async(
invocation_context: InvocationContext,
function_call_event: Event,
tools_dict: dict[str, BaseTool],
filters: Optional[set[str]] = None,
) -> Optional[Event]:
"""Calls the functions and returns the function response event."""
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
function_calls = function_call_event.get_function_calls()
function_response_events: list[Event] = []
for function_call in function_calls:
if filters and function_call.id not in filters:
continue
tool, tool_context = _get_tool_and_context(
invocation_context,
function_call_event,
function_call,
tools_dict,
)
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool=tool, args=function_args, tool_context=tool_context
)
if not function_response:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
# Calls after_tool_callback if it exists.
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
tool=tool,
args=function_args,
tool_context=tool_context,
tool_response=function_response,
)
if new_response:
function_response = new_response
if tool.is_long_running:
      # Allow a long-running function to return None to skip the function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
function_response_events.append(function_response_event)
if not function_response_events:
return None
merged_event = merge_parallel_function_response_events(
function_response_events
)
if len(function_response_events) > 1:
    # This is needed for debug traces of parallel calls.
    # The individual response with tool.name is traced in __build_response_event;
    # tool.name is dropped from the span name here because this is a merged event.
with tracer.start_as_current_span('tool_response'):
trace_tool_response(
invocation_context=invocation_context,
event_id=merged_event.id,
function_response_event=merged_event,
)
return merged_event
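# Illustrative sketch (not part of the original module): the before/after tool
# callback contract used above. The callback bodies and the `cache` mapping are
# made-up examples; any truthy return value from before_tool_callback skips the
# real tool call, and any truthy return from after_tool_callback replaces the
# tool's response.
#
#   def before_tool_callback(tool, args, tool_context):
#     if tool.name == 'search' and args.get('query') in cache:
#       return {'result': cache[args['query']]}   # short-circuits the tool
#     return None                                 # fall through to the tool
#
#   def after_tool_callback(tool, args, tool_context, tool_response):
#     return {'wrapped': tool_response}           # replaces the response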
async def handle_function_calls_live(
invocation_context: InvocationContext,
function_call_event: Event,
tools_dict: dict[str, BaseTool],
) -> Optional[Event]:
"""Calls the functions and returns the function response event."""
from ...agents.llm_agent import LlmAgent
agent = cast(LlmAgent, invocation_context.agent)
function_calls = function_call_event.get_function_calls()
function_response_events: list[Event] = []
for function_call in function_calls:
tool, tool_context = _get_tool_and_context(
invocation_context, function_call_event, function_call, tools_dict
)
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
function_response = None
# Calls the tool if before_tool_callback does not exist or returns None.
if agent.before_tool_callback:
function_response = agent.before_tool_callback(
tool, function_args, tool_context
)
if not function_response:
function_response = await _process_function_live_helper(
tool, tool_context, function_call, function_args, invocation_context
)
# Calls after_tool_callback if it exists.
if agent.after_tool_callback:
new_response = agent.after_tool_callback(
tool,
function_args,
tool_context,
function_response,
)
if new_response:
function_response = new_response
if tool.is_long_running:
      # Allow an async function to return None to skip the function response.
if not function_response:
continue
# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
function_response_events.append(function_response_event)
if not function_response_events:
return None
merged_event = merge_parallel_function_response_events(
function_response_events
)
return merged_event
async def _process_function_live_helper(
tool, tool_context, function_call, function_args, invocation_context
):
function_response = None
# Check if this is a stop_streaming function call
if (
function_call.name == 'stop_streaming'
and 'function_name' in function_args
):
function_name = function_args['function_name']
active_tasks = invocation_context.active_streaming_tools
if (
function_name in active_tasks
and active_tasks[function_name].task
and not active_tasks[function_name].task.done()
):
task = active_tasks[function_name].task
task.cancel()
try:
# Wait for the task to be cancelled
await asyncio.wait_for(task, timeout=1.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
# Log the specific condition
        if task.cancelled():
          logger.info(f'Task {function_name} was cancelled successfully')
        elif task.done():
          logger.info(f'Task {function_name} completed during cancellation')
        else:
          logger.warning(
              f'Task {function_name} might still be running after'
              ' cancellation timeout'
          )
function_response = {
'status': f'The task is not cancelled yet for {function_name}.'
}
if not function_response:
# Clean up the reference
active_tasks[function_name].task = None
function_response = {
'status': f'Successfully stopped streaming function {function_name}'
}
else:
function_response = {
'status': f'No active streaming function named {function_name} found'
}
elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
# for streaming tool use case
# we require the function to be a async generator function
async def run_tool_and_update_queue(tool, function_args, tool_context):
try:
async for result in __call_tool_live(
tool=tool,
args=function_args,
tool_context=tool_context,
invocation_context=invocation_context,
):
updated_content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=f'Function {tool.name} returned: {result}'
)
],
)
invocation_context.live_request_queue.send_content(updated_content)
except asyncio.CancelledError:
raise # Re-raise to properly propagate the cancellation
task = asyncio.create_task(
run_tool_and_update_queue(tool, function_args, tool_context)
)
if invocation_context.active_streaming_tools is None:
invocation_context.active_streaming_tools = {}
if tool.name in invocation_context.active_streaming_tools:
invocation_context.active_streaming_tools[tool.name].task = task
else:
invocation_context.active_streaming_tools[tool.name] = (
ActiveStreamingTool(task=task)
)
# Immediately return a pending response.
    # This is required by the current live model.
function_response = {
'status': (
'The function is running asynchronously and the results are'
' pending.'
)
}
else:
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
return function_response
def _get_tool_and_context(
invocation_context: InvocationContext,
function_call_event: Event,
function_call: types.FunctionCall,
tools_dict: dict[str, BaseTool],
):
if function_call.name not in tools_dict:
raise ValueError(
f'Function {function_call.name} is not found in the tools_dict.'
)
tool_context = ToolContext(
invocation_context=invocation_context,
function_call_id=function_call.id,
)
tool = tools_dict[function_call.name]
return (tool, tool_context)
async def __call_tool_live(
tool: BaseTool,
args: dict[str, object],
tool_context: ToolContext,
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Calls the tool asynchronously (awaiting the coroutine)."""
with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
trace_tool_call(args=args)
async for item in tool._call_live(
args=args,
tool_context=tool_context,
invocation_context=invocation_context,
):
yield item
async def __call_tool_async(
tool: BaseTool,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:
"""Calls the tool."""
with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
trace_tool_call(args=args)
return await tool.run_async(args=args, tool_context=tool_context)
def __build_response_event(
tool: BaseTool,
function_result: dict[str, object],
tool_context: ToolContext,
invocation_context: InvocationContext,
) -> Event:
with tracer.start_as_current_span(f'tool_response [{tool.name}]'):
    # The spec requires the result to be a dict.
if not isinstance(function_result, dict):
function_result = {'result': function_result}
part_function_response = types.Part.from_function_response(
name=tool.name, response=function_result
)
part_function_response.function_response.id = tool_context.function_call_id
content = types.Content(
role='user',
parts=[part_function_response],
)
function_response_event = Event(
invocation_id=invocation_context.invocation_id,
author=invocation_context.agent.name,
content=content,
actions=tool_context.actions,
branch=invocation_context.branch,
)
trace_tool_response(
invocation_context=invocation_context,
event_id=function_response_event.id,
function_response_event=function_response_event,
)
return function_response_event
def merge_parallel_function_response_events(
function_response_events: list['Event'],
) -> 'Event':
if not function_response_events:
raise ValueError('No function response events provided.')
if len(function_response_events) == 1:
return function_response_events[0]
merged_parts = []
for event in function_response_events:
if event.content:
for part in event.content.parts or []:
merged_parts.append(part)
# Use the first event as the "base" for common attributes
base_event = function_response_events[0]
# Merge actions from all events
merged_actions = EventActions()
merged_requested_auth_configs = {}
for event in function_response_events:
merged_requested_auth_configs.update(event.actions.requested_auth_configs)
merged_actions = merged_actions.model_copy(
update=event.actions.model_dump()
)
merged_actions.requested_auth_configs = merged_requested_auth_configs
# Create the new merged event
merged_event = Event(
invocation_id=Event.new_id(),
author=base_event.author,
branch=base_event.branch,
content=types.Content(role='user', parts=merged_parts),
      actions=merged_actions,  # Actions merged from all events above.
)
  # Use the base_event's timestamp.
merged_event.timestamp = base_event.timestamp
return merged_event
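# Illustrative sketch (not part of the original module): unlike the
# _merge_function_response_events helper earlier in this commit, this function
# simply concatenates the parts of events produced by parallel tool calls into
# one event. The tool names and payloads below are made up.
#
#   weather_event = __build_response_event(weather_tool, {'temp': 21},
#                                          weather_ctx, invocation_context)
#   news_event = __build_response_event(news_tool, {'headlines': []},
#                                       news_ctx, invocation_context)
#   merged = merge_parallel_function_response_events([weather_event, news_event])
#   # merged.content.parts contains both function_response parts, and
#   # merged.actions carries the union of requested_auth_configs.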

View File

@@ -0,0 +1,47 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gives the agent identity from the framework."""
from __future__ import annotations
from typing import AsyncGenerator
from typing_extensions import override
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ._base_llm_processor import BaseLlmRequestProcessor
class _IdentityLlmRequestProcessor(BaseLlmRequestProcessor):
"""Gives the agent identity from the framework."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
agent = invocation_context.agent
si = [f'You are an agent. Your internal name is "{agent.name}".']
if agent.description:
si.append(f' The description about you is "{agent.description}"')
llm_request.append_instructions(si)
# Maintain async generator behavior
if False: # Ensures it behaves as a generator
yield # This is a no-op but maintains generator structure
request_processor = _IdentityLlmRequestProcessor()
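# Illustrative sketch (not part of the original module): for an agent named
# 'helper' with description 'Answers billing questions.', the processor above
# appends roughly the following lines to the system instructions (the agent
# name and description are made up):
#
#   You are an agent. Your internal name is "helper".
#    The description about you is "Answers billing questions."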

View File

@@ -0,0 +1,137 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles instructions and global instructions for LLM flow."""
from __future__ import annotations
import re
from typing import AsyncGenerator
from typing import Generator
from typing import TYPE_CHECKING
from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
from ...events.event import Event
from ...sessions.state import State
from ._base_llm_processor import BaseLlmRequestProcessor
if TYPE_CHECKING:
from ...agents.invocation_context import InvocationContext
from ...models.llm_request import LlmRequest
class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
"""Handles instructions and global instructions for LLM flow."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.base_agent import BaseAgent
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
root_agent: BaseAgent = agent.root_agent
# Appends global instructions if set.
if (
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
): # not empty str
raw_si = root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context)
)
si = _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])
# Appends agent instructions if set.
if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
si = _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si])
# Maintain async generator behavior
if False: # Ensures it behaves as a generator
yield # This is a no-op but maintains generator structure
request_processor = _InstructionsLlmRequestProcessor()
def _populate_values(
instruction_template: str,
context: InvocationContext,
) -> str:
"""Populates values in the instruction template, e.g. state, artifact, etc."""
def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
optional = False
if var_name.endswith('?'):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
if context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
artifact = context.artifact_service.load_artifact(
app_name=context.session.app_name,
user_id=context.session.user_id,
session_id=context.session.id,
filename=var_name,
)
      if not artifact:
raise KeyError(f'Artifact {var_name} not found.')
return str(artifact)
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in context.session.state:
return str(context.session.state[var_name])
else:
if optional:
return ''
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
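# Illustrative sketch (not part of the original module): template substitution
# performed by _populate_values. The state keys, artifact name, and the `ctx`
# InvocationContext are made up for illustration.
#
#   # ctx.session.state == {'user_name': 'Ada', 'app:theme': 'dark'}
#   _populate_values('Hello {user_name}, theme={app:theme}.', ctx)
#   #   -> 'Hello Ada, theme=dark.'
#   _populate_values('Nickname: {nickname?}', ctx)
#   #   -> 'Nickname: '   (a missing optional variable resolves to '')
#   _populate_values('Report: {artifact.report}', ctx)
#   #   -> str() of the artifact 'report' loaded from the artifact service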
def _is_valid_state_name(var_name):
"""Checks if the variable name is a valid state name.
  A valid state name is either:
  - a valid identifier, or
  - <valid prefix>:<valid identifier>
  Everything else is not a valid state name.
Args:
var_name: The variable name to check.
Returns:
True if the variable name is a valid state name, False otherwise.
"""
parts = var_name.split(':')
if len(parts) == 1:
return var_name.isidentifier()
if len(parts) == 2:
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
if (parts[0] + ':') in prefixes:
return parts[1].isidentifier()
return False
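# Illustrative sketch (not part of the original module): examples of the check
# above, assuming State.APP_PREFIX == 'app:' and State.USER_PREFIX == 'user:'.
#
#   _is_valid_state_name('user_name')     # True  (plain identifier)
#   _is_valid_state_name('app:theme')     # True  (known prefix + identifier)
#   _is_valid_state_name('other:thing')   # False (unknown prefix)
#   _is_valid_state_name('user name')     # False (not an identifier)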

View File

@@ -0,0 +1,57 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of single flow."""
import logging
from ...auth import auth_preprocessor
from . import _code_execution
from . import _nl_planning
from . import basic
from . import contents
from . import identity
from . import instructions
from .base_llm_flow import BaseLlmFlow
logger = logging.getLogger(__name__)
class SingleFlow(BaseLlmFlow):
"""SingleFlow is the LLM flows that handles tools calls.
A single flow only consider an agent itself and tools.
No sub-agents are allowed for single flow.
"""
def __init__(self):
super().__init__()
self.request_processors += [
basic.request_processor,
auth_preprocessor.request_processor,
instructions.request_processor,
identity.request_processor,
contents.request_processor,
# Some implementations of NL Planning mark planning contents as thoughts
# in the post processor. Since these need to be unmarked, NL Planning
# should be after contents.
_nl_planning.request_processor,
# Code execution should be after the contents as it mutates the contents
# to optimize data files.
_code_execution.request_processor,
]
self.response_processors += [
_nl_planning.response_processor,
_code_execution.response_processor,
]
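# Illustrative sketch (not part of the original module): SingleFlow is
# typically instantiated without arguments; the ordering set above (basic ->
# auth -> instructions -> identity -> contents -> NL planning -> code
# execution) determines how each LLM request is assembled.
#
#   flow = SingleFlow()
#   # flow.request_processors and flow.response_processors now hold the
#   # ordered processor lists configured in __init__.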