structure saas with tools
@@ -0,0 +1,14 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Binary file not shown.
@@ -0,0 +1,20 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import _code_execution
from . import _nl_planning
from . import contents
from . import functions
from . import identity
from . import instructions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,52 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the processor interface used for BaseLlmFlow."""
from __future__ import annotations

from abc import ABC
from abc import abstractmethod
from typing import AsyncGenerator
from typing import TYPE_CHECKING

from ...agents.invocation_context import InvocationContext
from ...events.event import Event

if TYPE_CHECKING:
  from ...models.llm_request import LlmRequest
  from ...models.llm_response import LlmResponse


class BaseLlmRequestProcessor(ABC):
  """Base class for LLM request processor."""

  @abstractmethod
  async def run_async(
      self, invocation_context: InvocationContext, llm_request: LlmRequest
  ) -> AsyncGenerator[Event, None]:
    """Runs the processor."""
    raise NotImplementedError("Not implemented.")
    yield  # AsyncGenerator requires a yield in function body.


class BaseLlmResponseProcessor(ABC):
  """Base class for LLM response processor."""

  @abstractmethod
  async def run_async(
      self, invocation_context: InvocationContext, llm_response: LlmResponse
  ) -> AsyncGenerator[Event, None]:
    """Processes the LLM response."""
    raise NotImplementedError("Not implemented.")
    yield  # AsyncGenerator requires a yield in function body.
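

# --- Illustrative sketch (not part of this commit): the smallest possible
# --- request processor built on the interface above. The class name and the
# --- module-level instance are assumptions that mirror the pattern used by
# --- the concrete processors later in this change.
class _NoopLoggingRequestProcessor(BaseLlmRequestProcessor):
  """Example processor that only inspects the outgoing request."""

  async def run_async(
      self, invocation_context: InvocationContext, llm_request: LlmRequest
  ) -> AsyncGenerator[Event, None]:
    # Inspect or mutate llm_request here; yield Events only when needed.
    print(f'{invocation_context.agent.name}: {len(llm_request.contents)} contents')
    return
    yield  # AsyncGenerator requires a yield in function body.


request_processor = _NoopLoggingRequestProcessor()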
@@ -0,0 +1,458 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handles Code Execution related logic."""

from __future__ import annotations

import base64
import copy
import dataclasses
import os
import re
from typing import AsyncGenerator
from typing import Generator
from typing import Optional
from typing import TYPE_CHECKING

from google.genai import types
from typing_extensions import override

from ...agents.invocation_context import InvocationContext
from ...code_executors.base_code_executor import BaseCodeExecutor
from ...code_executors.code_execution_utils import CodeExecutionInput
from ...code_executors.code_execution_utils import CodeExecutionResult
from ...code_executors.code_execution_utils import CodeExecutionUtils
from ...code_executors.code_execution_utils import File
from ...code_executors.code_executor_context import CodeExecutorContext
from ...events.event import Event
from ...events.event_actions import EventActions
from ...models.llm_response import LlmResponse
from ._base_llm_processor import BaseLlmRequestProcessor
from ._base_llm_processor import BaseLlmResponseProcessor

if TYPE_CHECKING:
  from ...models.llm_request import LlmRequest


@dataclasses.dataclass
class DataFileUtil:
  """A structure that contains a data file name and its content."""

  extension: str
  """
  The file extension (e.g., ".csv").
  """

  loader_code_template: str
  """
  The code template to load the data file.
  """


_DATA_FILE_UTIL_MAP = {
    'text/csv': DataFileUtil(
        extension='.csv',
        loader_code_template="pd.read_csv('{filename}')",
    ),
}

_DATA_FILE_HELPER_LIB = '''
import pandas as pd

def explore_df(df: pd.DataFrame) -> None:
  """Prints some information about a pandas DataFrame."""

  with pd.option_context(
      'display.max_columns', None, 'display.expand_frame_repr', False
  ):
    # Print the column names to never encounter KeyError when selecting one.
    df_dtypes = df.dtypes

    # Obtain information about data types and missing values.
    df_nulls = (len(df) - df.isnull().sum()).apply(
        lambda x: f'{x} / {df.shape[0]} non-null'
    )

    # Explore unique total values in columns using `.unique()`.
    df_unique_count = df.apply(lambda x: len(x.unique()))

    # Explore unique values in columns using `.unique()`.
    df_unique = df.apply(lambda x: crop(str(list(x.unique()))))

    df_info = pd.concat(
        (
            df_dtypes.rename('Dtype'),
            df_nulls.rename('Non-Null Count'),
            df_unique_count.rename('Unique Values Count'),
            df_unique.rename('Unique Values'),
        ),
        axis=1,
    )
    df_info.index.name = 'Columns'
    print(f"""Total rows: {df.shape[0]}
Total columns: {df.shape[1]}

{df_info}""")
'''


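# --- Illustrative sketch (not part of this commit): supporting another
# --- tabular format would only need one more _DATA_FILE_UTIL_MAP entry.
# --- The TSV mime type and loader arguments below are assumptions.
# _DATA_FILE_UTIL_MAP['text/tab-separated-values'] = DataFileUtil(
#     extension='.tsv',
#     loader_code_template="pd.read_csv('{filename}', sep='\\t')",
# )

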
class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
  """Processes code execution requests."""

  @override
  async def run_async(
      self, invocation_context: InvocationContext, llm_request: LlmRequest
  ) -> AsyncGenerator[Event, None]:
    from ...agents.llm_agent import LlmAgent

    if not isinstance(invocation_context.agent, LlmAgent):
      return
    if not invocation_context.agent.code_executor:
      return

    for event in _run_pre_processor(invocation_context, llm_request):
      yield event

    # Convert the code execution parts to text parts.
    if not isinstance(invocation_context.agent.code_executor, BaseCodeExecutor):
      return
    for content in llm_request.contents:
      CodeExecutionUtils.convert_code_execution_parts(
          content,
          invocation_context.agent.code_executor.code_block_delimiters[0]
          if invocation_context.agent.code_executor.code_block_delimiters
          else ('', ''),
          invocation_context.agent.code_executor.execution_result_delimiters,
      )


request_processor = _CodeExecutionRequestProcessor()


class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
  """Processes code execution responses."""

  @override
  async def run_async(
      self, invocation_context: InvocationContext, llm_response: LlmResponse
  ) -> AsyncGenerator[Event, None]:
    # Skip if the response is partial (streaming).
    if llm_response.partial:
      return

    for event in _run_post_processor(invocation_context, llm_response):
      yield event


response_processor = _CodeExecutionResponseProcessor()


def _run_pre_processor(
    invocation_context: InvocationContext,
    llm_request: LlmRequest,
) -> Generator[Event, None, None]:
  """Pre-process the user message by adding the user message to the Colab notebook."""
  from ...agents.llm_agent import LlmAgent

  if not isinstance(invocation_context.agent, LlmAgent):
    return

  agent = invocation_context.agent
  code_executor = agent.code_executor

  if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
    return
  if not code_executor.optimize_data_file:
    return

  code_executor_context = CodeExecutorContext(invocation_context.session.state)

  # Skip if the error count exceeds the max retry attempts.
  if (
      code_executor_context.get_error_count(invocation_context.invocation_id)
      >= code_executor.error_retry_attempts
  ):
    return

  # [Step 1] Extract data files from the session_history and store them in
  # memory. Meanwhile, mutate the inline data file to text part in session
  # history from all turns.
  all_input_files = _extrac_and_replace_inline_files(
      code_executor_context, llm_request
  )

  # [Step 2] Run the `explore_df` code on the data files from the current turn.
  # We only need to explore the new data files because the previous data files
  # should already be explored and cached in the code execution runtime.
  processed_file_names = set(code_executor_context.get_processed_file_names())
  files_to_process = [
      f for f in all_input_files if f.name not in processed_file_names
  ]
  for file in files_to_process:
    code_str = _get_data_file_preprocessing_code(file)
    # Skip for unsupported file or executor types.
    if not code_str:
      return

    # Emit the code to execute, and add it to the LLM request.
    code_content = types.Content(
        role='model',
        parts=[
            types.Part(text=f'Processing input file: `{file.name}`'),
            CodeExecutionUtils.build_executable_code_part(code_str),
        ],
    )
    llm_request.contents.append(copy.deepcopy(code_content))
    yield Event(
        invocation_id=invocation_context.invocation_id,
        author=agent.name,
        branch=invocation_context.branch,
        content=code_content,
    )

    code_execution_result = code_executor.execute_code(
        invocation_context,
        CodeExecutionInput(
            code=code_str,
            input_files=[file],
            execution_id=_get_or_set_execution_id(
                invocation_context, code_executor_context
            ),
        ),
    )
    # Update the processing results to code executor context.
    code_executor_context.update_code_execution_result(
        invocation_context.invocation_id,
        code_str,
        code_execution_result.stdout,
        code_execution_result.stderr,
    )
    code_executor_context.add_processed_file_names([file.name])

    # Emit the execution result, and add it to the LLM request.
    execution_result_event = _post_process_code_execution_result(
        invocation_context, code_executor_context, code_execution_result
    )
    yield execution_result_event
    llm_request.contents.append(copy.deepcopy(execution_result_event.content))


def _run_post_processor(
    invocation_context: InvocationContext,
    llm_response,
) -> Generator[Event, None, None]:
  """Post-process the model response by extracting and executing the first code block."""
  agent = invocation_context.agent
  code_executor = agent.code_executor

  if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
    return
  if not llm_response or not llm_response.content:
    return

  code_executor_context = CodeExecutorContext(invocation_context.session.state)
  # Skip if the error count exceeds the max retry attempts.
  if (
      code_executor_context.get_error_count(invocation_context.invocation_id)
      >= code_executor.error_retry_attempts
  ):
    return

  # [Step 1] Extract code from the model response and truncate the content to
  # the part with the first code block.
  response_content = llm_response.content
  code_str = CodeExecutionUtils.extract_code_and_truncate_content(
      response_content, code_executor.code_block_delimiters
  )
  # Terminal state: no code to execute.
  if not code_str:
    return

  # [Step 2] Execute the code and emit two Events, one for the code and one
  # for the execution result.
  yield Event(
      invocation_id=invocation_context.invocation_id,
      author=agent.name,
      branch=invocation_context.branch,
      content=response_content,
      actions=EventActions(),
  )

  code_execution_result = code_executor.execute_code(
      invocation_context,
      CodeExecutionInput(
          code=code_str,
          input_files=code_executor_context.get_input_files(),
          execution_id=_get_or_set_execution_id(
              invocation_context, code_executor_context
          ),
      ),
  )
  code_executor_context.update_code_execution_result(
      invocation_context.invocation_id,
      code_str,
      code_execution_result.stdout,
      code_execution_result.stderr,
  )
  yield _post_process_code_execution_result(
      invocation_context, code_executor_context, code_execution_result
  )

  # [Step 3] Skip processing the original model response
  # to continue code generation loop.
  llm_response.content = None


def _extrac_and_replace_inline_files(
    code_executor_context: CodeExecutorContext,
    llm_request: LlmRequest,
) -> list[File]:
  """Extracts and replaces inline files with file names in the LLM request."""
  all_input_files = code_executor_context.get_input_files()
  saved_file_names = set(f.name for f in all_input_files)

  # [Step 1] Process input files from LlmRequest and cache them in CodeExecutor.
  for i in range(len(llm_request.contents)):
    content = llm_request.contents[i]
    # Only process the user message.
    if content.role != 'user' and not content.parts:
      continue

    for j in range(len(content.parts)):
      part = content.parts[j]
      # Skip if the inline data is not supported.
      if (
          not part.inline_data
          or part.inline_data.mime_type not in _DATA_FILE_UTIL_MAP
      ):
        continue

      # Replace the inline data file with a file name placeholder.
      mime_type = part.inline_data.mime_type
      file_name = f'data_{i+1}_{j+1}' + _DATA_FILE_UTIL_MAP[mime_type].extension
      llm_request.contents[i].parts[j] = types.Part(
          text='\nAvailable file: `%s`\n' % file_name
      )

      # Add the inline data as an input file to the code executor context.
      file = File(
          name=file_name,
          content=CodeExecutionUtils.get_encoded_file_content(
              part.inline_data.data
          ).decode(),
          mime_type=mime_type,
      )
      if file_name not in saved_file_names:
        code_executor_context.add_input_files([file])
        all_input_files.append(file)

  return all_input_files


def _get_or_set_execution_id(
    invocation_context: InvocationContext,
    code_executor_context: CodeExecutorContext,
) -> Optional[str]:
  """Returns the ID for stateful code execution or None if not stateful."""
  if not invocation_context.agent.code_executor.stateful:
    return None

  execution_id = code_executor_context.get_execution_id()
  if not execution_id:
    execution_id = invocation_context.session.id
    code_executor_context.set_execution_id(execution_id)
  return execution_id


def _post_process_code_execution_result(
    invocation_context: InvocationContext,
    code_executor_context: CodeExecutorContext,
    code_execution_result: CodeExecutionResult,
) -> Event:
  """Post-process the code execution result and emit an Event."""
  if invocation_context.artifact_service is None:
    raise ValueError('Artifact service is not initialized.')

  result_content = types.Content(
      role='model',
      parts=[
          CodeExecutionUtils.build_code_execution_result_part(
              code_execution_result
          ),
      ],
  )
  event_actions = EventActions(
      state_delta=code_executor_context.get_state_delta()
  )

  # Handle code execution error retry.
  if code_execution_result.stderr:
    code_executor_context.increment_error_count(
        invocation_context.invocation_id
    )
  else:
    code_executor_context.reset_error_count(invocation_context.invocation_id)

  # Handle output files.
  for output_file in code_execution_result.output_files:
    version = invocation_context.artifact_service.save_artifact(
        app_name=invocation_context.app_name,
        user_id=invocation_context.user_id,
        session_id=invocation_context.session.id,
        filename=output_file.name,
        artifact=types.Part.from_bytes(
            data=base64.b64decode(output_file.content),
            mime_type=output_file.mime_type,
        ),
    )
    event_actions.artifact_delta[output_file.name] = version

  return Event(
      invocation_id=invocation_context.invocation_id,
      author=invocation_context.agent.name,
      branch=invocation_context.branch,
      content=result_content,
      actions=event_actions,
  )


def _get_data_file_preprocessing_code(file: File) -> Optional[str]:
  """Returns the code to explore the data file."""

  def _get_normalized_file_name(file_name: str) -> str:
    var_name, _ = os.path.splitext(file_name)
    # Replace non-alphanumeric characters with underscores
    var_name = re.sub(r'[^a-zA-Z0-9_]', '_', var_name)

    # If the filename starts with a digit, prepend an underscore
    if var_name[0].isdigit():
      var_name = '_' + var_name
    return var_name

  if file.mime_type not in _DATA_FILE_UTIL_MAP:
    return

  var_name = _get_normalized_file_name(file.name)
  loader_code = _DATA_FILE_UTIL_MAP[file.mime_type].loader_code_template.format(
      filename=file.name
  )
  return f"""
{_DATA_FILE_HELPER_LIB}

# Load the dataframe.
{var_name} = {loader_code}

# Use `explore_df` to guide my analysis.
explore_df({var_name})
"""
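

# --- Illustrative sketch (not part of this commit): for a hypothetical
# --- uploaded file named "sales data.csv", _get_data_file_preprocessing_code
# --- would emit roughly the following snippet (helper lib omitted):
#
#   # Load the dataframe.
#   sales_data = pd.read_csv('sales data.csv')
#
#   # Use `explore_df` to guide my analysis.
#   explore_df(sales_data)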
@@ -0,0 +1,135 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handles NL planning related logic."""

from __future__ import annotations

from typing import AsyncGenerator
from typing import Generator
from typing import Optional
from typing import TYPE_CHECKING

from typing_extensions import override

from ...agents.callback_context import CallbackContext
from ...agents.invocation_context import InvocationContext
from ...agents.readonly_context import ReadonlyContext
from ...events.event import Event
from ...planners.plan_re_act_planner import PlanReActPlanner
from ._base_llm_processor import BaseLlmRequestProcessor
from ._base_llm_processor import BaseLlmResponseProcessor

if TYPE_CHECKING:
  from ...models.llm_request import LlmRequest
  from ...models.llm_response import LlmResponse
  from ...planners.base_planner import BasePlanner
  from ...planners.built_in_planner import BuiltInPlanner


class _NlPlanningRequestProcessor(BaseLlmRequestProcessor):
  """Processor for NL planning."""

  async def run_async(
      self, invocation_context: InvocationContext, llm_request: LlmRequest
  ) -> AsyncGenerator[Event, None]:
    from ...planners.built_in_planner import BuiltInPlanner

    planner = _get_planner(invocation_context)
    if not planner:
      return

    if isinstance(planner, BuiltInPlanner):
      planner.apply_thinking_config(llm_request)

    planning_instruction = planner.build_planning_instruction(
        ReadonlyContext(invocation_context), llm_request
    )
    if planning_instruction:
      llm_request.append_instructions([planning_instruction])

    _remove_thought_from_request(llm_request)

    # Maintain async generator behavior
    if False:  # Ensures it behaves as a generator
      yield  # This is a no-op but maintains generator structure


request_processor = _NlPlanningRequestProcessor()


class _NlPlanningResponse(BaseLlmResponseProcessor):

  @override
  async def run_async(
      self, invocation_context: InvocationContext, llm_response: LlmResponse
  ) -> AsyncGenerator[Event, None]:
    if (
        not llm_response
        or not llm_response.content
        or not llm_response.content.parts
    ):
      return

    planner = _get_planner(invocation_context)
    if not planner:
      return

    # Postprocess the LLM response.
    callback_context = CallbackContext(invocation_context)
    processed_parts = planner.process_planning_response(
        callback_context, llm_response.content.parts
    )
    if processed_parts:
      llm_response.content.parts = processed_parts

    if callback_context.state.has_delta():
      state_update_event = Event(
          invocation_id=invocation_context.invocation_id,
          author=invocation_context.agent.name,
          branch=invocation_context.branch,
          actions=callback_context._event_actions,
      )
      yield state_update_event


response_processor = _NlPlanningResponse()


def _get_planner(
    invocation_context: InvocationContext,
) -> Optional[BasePlanner]:
  from ...agents.llm_agent import Agent
  from ...planners.base_planner import BasePlanner

  agent = invocation_context.agent
  if not isinstance(agent, Agent):
    return None
  if not agent.planner:
    return None

  if isinstance(agent.planner, BasePlanner):
    return agent.planner
  return PlanReActPlanner()


def _remove_thought_from_request(llm_request: LlmRequest):
  if not llm_request.contents:
    return

  for content in llm_request.contents:
    if not content.parts:
      continue
    for part in content.parts:
      part.thought = None
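

# --- Illustrative sketch (not part of this commit): _get_planner returns the
# --- agent's planner when it is a BasePlanner and otherwise falls back to
# --- PlanReActPlanner, so a typical setup simply attaches a planner to the
# --- agent. The constructor arguments below are assumptions for illustration.
# agent = Agent(
#     name='researcher',
#     model='gemini-2.0-flash',
#     planner=PlanReActPlanner(),
# )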
@@ -0,0 +1,132 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handles agent transfer for LLM flow."""

from __future__ import annotations

import typing
from typing import AsyncGenerator

from typing_extensions import override

from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ...tools.function_tool import FunctionTool
from ...tools.tool_context import ToolContext
from ...tools.transfer_to_agent_tool import transfer_to_agent
from ._base_llm_processor import BaseLlmRequestProcessor

if typing.TYPE_CHECKING:
  from ...agents import BaseAgent
  from ...agents import LlmAgent


class _AgentTransferLlmRequestProcessor(BaseLlmRequestProcessor):
  """Agent transfer request processor."""

  @override
  async def run_async(
      self, invocation_context: InvocationContext, llm_request: LlmRequest
  ) -> AsyncGenerator[Event, None]:
    from ...agents.llm_agent import LlmAgent

    if not isinstance(invocation_context.agent, LlmAgent):
      return

    transfer_targets = _get_transfer_targets(invocation_context.agent)
    if not transfer_targets:
      return

    llm_request.append_instructions([
        _build_target_agents_instructions(
            invocation_context.agent, transfer_targets
        )
    ])

    transfer_to_agent_tool = FunctionTool(func=transfer_to_agent)
    tool_context = ToolContext(invocation_context)
    await transfer_to_agent_tool.process_llm_request(
        tool_context=tool_context, llm_request=llm_request
    )

    return
    yield  # AsyncGenerator requires yield statement in function body.


request_processor = _AgentTransferLlmRequestProcessor()


def _build_target_agents_info(target_agent: BaseAgent) -> str:
  return f"""
Agent name: {target_agent.name}
Agent description: {target_agent.description}
"""


line_break = '\n'


def _build_target_agents_instructions(
    agent: LlmAgent, target_agents: list[BaseAgent]
) -> str:
  si = f"""
You have a list of other agents to transfer to:

{line_break.join([
    _build_target_agents_info(target_agent) for target_agent in target_agents
])}

If you are the best to answer the question according to your description, you
can answer it.

If another agent is better for answering the question according to its
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
question to that agent. When transferring, do not generate any text other than
the function call.
"""

  if agent.parent_agent:
    si += f"""
Your parent agent is {agent.parent_agent.name}. If neither the other agents nor
you are best for answering the question according to the descriptions, transfer
to your parent agent. If you don't have a parent agent, try to answer by yourself.
"""
  return si


_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__


def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]:
  from ...agents.llm_agent import LlmAgent

  result = []
  result.extend(agent.sub_agents)

  if not agent.parent_agent or not isinstance(agent.parent_agent, LlmAgent):
    return result

  if not agent.disallow_transfer_to_parent:
    result.append(agent.parent_agent)

  if not agent.disallow_transfer_to_peers:
    result.extend([
        peer_agent
        for peer_agent in agent.parent_agent.sub_agents
        if peer_agent.name != agent.name
    ])

  return result
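

# --- Illustrative sketch (not part of this commit): how the helper above
# --- resolves transfer targets for a hypothetical agent tree. Names are
# --- made up for illustration.
#
#   root (LlmAgent)
#   ├── billing (this agent, sub_agents=[refunds])
#   └── support
#
# _get_transfer_targets(billing) would return [refunds, root, support],
# unless billing sets disallow_transfer_to_parent or
# disallow_transfer_to_peers, which drop root or support respectively.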
@@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from google.cloud import speech
from google.genai import types as genai_types

if TYPE_CHECKING:
  from ...agents.invocation_context import InvocationContext


class AudioTranscriber:
  """Transcribes audio using Google Cloud Speech-to-Text."""

  def __init__(self):
    self.client = speech.SpeechClient()

  def transcribe_file(
      self, invocation_context: InvocationContext
  ) -> list[genai_types.Content]:
    """Transcribe audio, bundling consecutive segments from the same speaker.

    The ordering of speakers is preserved. Audio blobs from the same speaker
    are merged as much as possible to reduce transcription latency.

    Args:
      invocation_context: The invocation context to access the transcription
        cache.

    Returns:
      A list of Content objects containing the transcribed text.
    """

    bundled_audio = []
    current_speaker = None
    current_audio_data = b''
    contents = []

    # Step 1: merge audio blobs
    for transcription_entry in invocation_context.transcription_cache or []:
      speaker, audio_data = (
          transcription_entry.role,
          transcription_entry.data,
      )

      if isinstance(audio_data, genai_types.Content):
        if current_speaker is not None:
          bundled_audio.append((current_speaker, current_audio_data))
          current_speaker = None
          current_audio_data = b''
        bundled_audio.append((speaker, audio_data))
        continue

      if not audio_data.data:
        continue

      if speaker == current_speaker:
        current_audio_data += audio_data.data
      else:
        if current_speaker is not None:
          bundled_audio.append((current_speaker, current_audio_data))
        current_speaker = speaker
        current_audio_data = audio_data.data

    # Append the last audio segment if any
    if current_speaker is not None:
      bundled_audio.append((current_speaker, current_audio_data))

    # Reset the cache.
    invocation_context.transcription_cache = []

    # Step 2: transcription
    for speaker, data in bundled_audio:
      if speaker == 'user':
        audio = speech.RecognitionAudio(content=data)

        config = speech.RecognitionConfig(
            encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
            sample_rate_hertz=16000,
            language_code='en-US',
        )

        response = self.client.recognize(config=config, audio=audio)

        for result in response.results:
          transcript = result.alternatives[0].transcript

          parts = [genai_types.Part(text=transcript)]
          role = speaker.lower()
          content = genai_types.Content(role=role, parts=parts)
          contents.append(content)
      else:
        # No need to transcribe model output, which is already text.
        contents.append(data)

    return contents
@@ -0,0 +1,49 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of AutoFlow."""

from . import agent_transfer
from .single_flow import SingleFlow


class AutoFlow(SingleFlow):
  """AutoFlow is SingleFlow with agent transfer capability.

  Agent transfer is allowed in the following directions:

  1. from parent to sub-agent;
  2. from sub-agent to parent;
  3. from sub-agent to its peer agents;

  Peer-agent transfer is only enabled when all the conditions below are met:

  - The parent agent is also of AutoFlow;
  - The `disallow_transfer_to_peers` option of this agent is False (default).

  Depending on the target agent's flow type, the transfer may be automatically
  reversed. The conditions are as follows:

  - If the flow type of the transferee agent is also auto, the transferee agent
    will remain the active agent and will respond to the user's next message
    directly.
  - If the flow type of the transferee agent is not auto, the active agent will
    be reversed back to the previous agent.

  TODO: allow users to configure the auto-reverse behavior.
  """

  def __init__(self):
    super().__init__()
    self.request_processors += [agent_transfer.request_processor]
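

# --- Illustrative sketch (not part of this commit): AutoFlow's behavior is
# --- purely compositional. A hypothetical flow that also enabled code
# --- execution on top of agent transfer would extend the processor lists in
# --- the same way (the combination below is an assumption).
# class AutoFlowWithCodeExecution(AutoFlow):
#
#   def __init__(self):
#     super().__init__()
#     self.request_processors += [_code_execution.request_processor]
#     self.response_processors += [_code_execution.response_processor]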
@@ -0,0 +1,559 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC
import asyncio
import logging
from typing import AsyncGenerator
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING

from websockets.exceptions import ConnectionClosedOK

from ...agents.base_agent import BaseAgent
from ...agents.callback_context import CallbackContext
from ...agents.invocation_context import InvocationContext
from ...agents.live_request_queue import LiveRequestQueue
from ...agents.run_config import StreamingMode
from ...agents.transcription_entry import TranscriptionEntry
from ...events.event import Event
from ...models.base_llm_connection import BaseLlmConnection
from ...models.llm_request import LlmRequest
from ...models.llm_response import LlmResponse
from ...telemetry import trace_call_llm
from ...telemetry import trace_send_data
from ...telemetry import tracer
from ...tools.tool_context import ToolContext
from . import functions

if TYPE_CHECKING:
  from ...agents.llm_agent import LlmAgent
  from ...models.base_llm import BaseLlm
  from ._base_llm_processor import BaseLlmRequestProcessor
  from ._base_llm_processor import BaseLlmResponseProcessor

logger = logging.getLogger(__name__)


class BaseLlmFlow(ABC):
  """A basic flow that calls the LLM in a loop until a final response is generated.

  This flow ends when it transfers to another agent.
  """

  def __init__(self):
    self.request_processors: list[BaseLlmRequestProcessor] = []
    self.response_processors: list[BaseLlmResponseProcessor] = []

  async def run_live(
      self,
      invocation_context: InvocationContext,
  ) -> AsyncGenerator[Event, None]:
    """Runs the flow using the live API."""
    llm_request = LlmRequest()
    event_id = Event.new_id()

    # Preprocess before calling the LLM.
    async for event in self._preprocess_async(invocation_context, llm_request):
      yield event
    if invocation_context.end_invocation:
      return

    llm = self.__get_llm(invocation_context)
    logger.debug(
        'Establishing live connection for agent: %s with llm request: %s',
        invocation_context.agent.name,
        llm_request,
    )
    async with llm.connect(llm_request) as llm_connection:
      if llm_request.contents:
        # Sends the conversation history to the model.
        with tracer.start_as_current_span('send_data'):

          if invocation_context.transcription_cache:
            from . import audio_transcriber

            audio_transcriber = audio_transcriber.AudioTranscriber()
            contents = audio_transcriber.transcribe_file(invocation_context)
            logger.debug('Sending history to model: %s', contents)
            await llm_connection.send_history(contents)
            invocation_context.transcription_cache = None
            trace_send_data(invocation_context, event_id, contents)
          else:
            await llm_connection.send_history(llm_request.contents)
            trace_send_data(invocation_context, event_id, llm_request.contents)

      send_task = asyncio.create_task(
          self._send_to_model(llm_connection, invocation_context)
      )

      try:
        async for event in self._receive_from_model(
            llm_connection,
            event_id,
            invocation_context,
            llm_request,
        ):
          # Empty event means the queue is closed.
          if not event:
            break
          logger.debug('Receive new event: %s', event)
          yield event
          # Send back the function response.
          if event.get_function_responses():
            logger.debug('Sending back last function response event: %s', event)
            invocation_context.live_request_queue.send_content(event.content)
          if (
              event.content
              and event.content.parts
              and event.content.parts[0].function_response
              and event.content.parts[0].function_response.name
              == 'transfer_to_agent'
          ):
            await asyncio.sleep(1)
            # Cancel the tasks that belong to the closed connection.
            send_task.cancel()
            await llm_connection.close()
      finally:
        # Clean up
        if not send_task.done():
          send_task.cancel()
          try:
            await send_task
          except asyncio.CancelledError:
            pass

  async def _send_to_model(
      self,
      llm_connection: BaseLlmConnection,
      invocation_context: InvocationContext,
  ):
    """Sends data to model."""
    while True:
      live_request_queue = invocation_context.live_request_queue
      try:
        # Streamlit's execution model doesn't preemptively yield to the event
        # loop. Therefore, we must explicitly introduce timeouts to allow the
        # event loop to process events.
        # TODO: revert back (remove the timeout) once we move off Streamlit.
        live_request = await asyncio.wait_for(
            live_request_queue.get(), timeout=0.25
        )
        # Duplicate the live_request to all the active streams.
        logger.debug(
            'Sending live request %s to active streams: %s',
            live_request,
            invocation_context.active_streaming_tools,
        )
        if invocation_context.active_streaming_tools:
          for active_streaming_tool in (
              invocation_context.active_streaming_tools
          ).values():
            if active_streaming_tool.stream:
              active_streaming_tool.stream.send(live_request)
        await asyncio.sleep(0)
      except asyncio.TimeoutError:
        continue
      if live_request.close:
        await llm_connection.close()
        return
      if live_request.blob:
        # Cache audio data here for transcription
        if not invocation_context.transcription_cache:
          invocation_context.transcription_cache = []
        invocation_context.transcription_cache.append(
            TranscriptionEntry(role='user', data=live_request.blob)
        )
        await llm_connection.send_realtime(live_request.blob)
      if live_request.content:
        await llm_connection.send_content(live_request.content)

  async def _receive_from_model(
      self,
      llm_connection: BaseLlmConnection,
      event_id: str,
      invocation_context: InvocationContext,
      llm_request: LlmRequest,
  ) -> AsyncGenerator[Event, None]:
    """Receive data from model and process events using BaseLlmConnection."""
    assert invocation_context.live_request_queue
    try:
      while True:
        async for llm_response in llm_connection.receive():
          model_response_event = Event(
              id=Event.new_id(),
              invocation_id=invocation_context.invocation_id,
              author=invocation_context.agent.name,
          )
          async for event in self._postprocess_live(
              invocation_context,
              llm_request,
              llm_response,
              model_response_event,
          ):
            if (
                event.content
                and event.content.parts
                and event.content.parts[0].text
                and not event.partial
            ):
              if not invocation_context.transcription_cache:
                invocation_context.transcription_cache = []
              invocation_context.transcription_cache.append(
                  TranscriptionEntry(role='model', data=event.content)
              )
            yield event
        # Give opportunity for other tasks to run.
        await asyncio.sleep(0)
    except ConnectionClosedOK:
      pass

  async def run_async(
      self, invocation_context: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Runs the flow."""
    while True:
      last_event = None
      async for event in self._run_one_step_async(invocation_context):
        last_event = event
        yield event
      if not last_event or last_event.is_final_response():
        break

  async def _run_one_step_async(
      self,
      invocation_context: InvocationContext,
  ) -> AsyncGenerator[Event, None]:
    """One step means one LLM call."""
    llm_request = LlmRequest()

    # Preprocess before calling the LLM.
    async for event in self._preprocess_async(invocation_context, llm_request):
      yield event
    if invocation_context.end_invocation:
      return

    # Calls the LLM.
    model_response_event = Event(
        id=Event.new_id(),
        invocation_id=invocation_context.invocation_id,
        author=invocation_context.agent.name,
        branch=invocation_context.branch,
    )
    async for llm_response in self._call_llm_async(
        invocation_context, llm_request, model_response_event
    ):
      # Postprocess after calling the LLM.
      async for event in self._postprocess_async(
          invocation_context, llm_request, llm_response, model_response_event
      ):
        yield event

  async def _preprocess_async(
      self, invocation_context: InvocationContext, llm_request: LlmRequest
  ) -> AsyncGenerator[Event, None]:
    from ...agents.llm_agent import LlmAgent

    agent = invocation_context.agent
    if not isinstance(agent, LlmAgent):
      return

    # Runs processors.
    for processor in self.request_processors:
      async for event in processor.run_async(invocation_context, llm_request):
        yield event

    # Run processors for tools.
    for tool in agent.canonical_tools:
      tool_context = ToolContext(invocation_context)
      await tool.process_llm_request(
          tool_context=tool_context, llm_request=llm_request
      )

  async def _postprocess_async(
      self,
      invocation_context: InvocationContext,
      llm_request: LlmRequest,
      llm_response: LlmResponse,
      model_response_event: Event,
  ) -> AsyncGenerator[Event, None]:
    """Postprocess after calling the LLM.

    Args:
      invocation_context: The invocation context.
      llm_request: The original LLM request.
      llm_response: The LLM response from the LLM call.
      model_response_event: A mutable event for the LLM response.

    Yields:
      A generator of events.
    """

    # Runs processors.
    async for event in self._postprocess_run_processors_async(
        invocation_context, llm_response
    ):
      yield event

    # Skip the model response event if there is no content and no error code.
    # This is needed for the code executor to trigger another loop.
    if (
        not llm_response.content
        and not llm_response.error_code
        and not llm_response.interrupted
    ):
      return

    # Builds the event.
    model_response_event = self._finalize_model_response_event(
        llm_request, llm_response, model_response_event
    )
    yield model_response_event

    # Handles function calls.
    if model_response_event.get_function_calls():
      async for event in self._postprocess_handle_function_calls_async(
          invocation_context, model_response_event, llm_request
      ):
        yield event

  async def _postprocess_live(
      self,
      invocation_context: InvocationContext,
      llm_request: LlmRequest,
      llm_response: LlmResponse,
      model_response_event: Event,
  ) -> AsyncGenerator[Event, None]:
    """Postprocess after calling the LLM asynchronously.

    Args:
      invocation_context: The invocation context.
      llm_request: The original LLM request.
      llm_response: The LLM response from the LLM call.
      model_response_event: A mutable event for the LLM response.

    Yields:
      A generator of events.
    """

    # Runs processors.
    async for event in self._postprocess_run_processors_async(
        invocation_context, llm_response
    ):
      yield event

    # Skip the model response event if there is no content and no error code.
    # This is needed for the code executor to trigger another loop.
    # But don't skip control events like turn_complete.
    if (
        not llm_response.content
        and not llm_response.error_code
        and not llm_response.interrupted
        and not llm_response.turn_complete
    ):
      return

    # Builds the event.
    model_response_event = self._finalize_model_response_event(
        llm_request, llm_response, model_response_event
    )
    yield model_response_event

    # Handles function calls.
    if model_response_event.get_function_calls():
      function_response_event = await functions.handle_function_calls_live(
          invocation_context, model_response_event, llm_request.tools_dict
      )
      yield function_response_event

      transfer_to_agent = function_response_event.actions.transfer_to_agent
      if transfer_to_agent:
        agent_to_run = self._get_agent_to_run(
            invocation_context, transfer_to_agent
        )
        async for item in agent_to_run.run_live(invocation_context):
          yield item

  async def _postprocess_run_processors_async(
      self, invocation_context: InvocationContext, llm_response: LlmResponse
  ) -> AsyncGenerator[Event, None]:
    for processor in self.response_processors:
      async for event in processor.run_async(invocation_context, llm_response):
        yield event

  async def _postprocess_handle_function_calls_async(
      self,
      invocation_context: InvocationContext,
      function_call_event: Event,
      llm_request: LlmRequest,
  ) -> AsyncGenerator[Event, None]:
    if function_response_event := await functions.handle_function_calls_async(
        invocation_context, function_call_event, llm_request.tools_dict
    ):
      auth_event = functions.generate_auth_event(
          invocation_context, function_response_event
      )
      if auth_event:
        yield auth_event

      yield function_response_event
      transfer_to_agent = function_response_event.actions.transfer_to_agent
      if transfer_to_agent:
        agent_to_run = self._get_agent_to_run(
            invocation_context, transfer_to_agent
        )
        async for event in agent_to_run.run_async(invocation_context):
          yield event

  def _get_agent_to_run(
      self, invocation_context: InvocationContext, transfer_to_agent
  ) -> BaseAgent:
    root_agent = invocation_context.agent.root_agent
    agent_to_run = root_agent.find_agent(transfer_to_agent)
    if not agent_to_run:
      raise ValueError(
          f'Agent {transfer_to_agent} not found in the agent tree.'
      )
    return agent_to_run

  async def _call_llm_async(
      self,
      invocation_context: InvocationContext,
      llm_request: LlmRequest,
      model_response_event: Event,
  ) -> AsyncGenerator[LlmResponse, None]:
    # Runs before_model_callback if it exists.
    if response := self._handle_before_model_callback(
        invocation_context, llm_request, model_response_event
    ):
      yield response
      return

    # Calls the LLM.
    llm = self.__get_llm(invocation_context)
    with tracer.start_as_current_span('call_llm'):
      if invocation_context.run_config.support_cfc:
        invocation_context.live_request_queue = LiveRequestQueue()
        async for llm_response in self.run_live(invocation_context):
          # Runs after_model_callback if it exists.
          if altered_llm_response := self._handle_after_model_callback(
              invocation_context, llm_response, model_response_event
          ):
            llm_response = altered_llm_response
          # Only yield partial responses in SSE streaming mode.
          if (
              invocation_context.run_config.streaming_mode == StreamingMode.SSE
              or not llm_response.partial
          ):
            yield llm_response
        if llm_response.turn_complete:
          invocation_context.live_request_queue.close()
      else:
        # Check whether we can make this LLM call. If the current call pushes
        # the counter beyond the configured maximum, execution stops here and
        # an exception is thrown.
        invocation_context.increment_llm_call_count()
        async for llm_response in llm.generate_content_async(
            llm_request,
            stream=invocation_context.run_config.streaming_mode
            == StreamingMode.SSE,
        ):
          trace_call_llm(
              invocation_context,
              model_response_event.id,
              llm_request,
              llm_response,
          )
          # Runs after_model_callback if it exists.
          if altered_llm_response := self._handle_after_model_callback(
              invocation_context, llm_response, model_response_event
          ):
            llm_response = altered_llm_response

          yield llm_response

  def _handle_before_model_callback(
      self,
      invocation_context: InvocationContext,
      llm_request: LlmRequest,
      model_response_event: Event,
  ) -> Optional[LlmResponse]:
    from ...agents.llm_agent import LlmAgent

    agent = invocation_context.agent
    if not isinstance(agent, LlmAgent):
      return

    if not agent.before_model_callback:
      return

    callback_context = CallbackContext(
        invocation_context, event_actions=model_response_event.actions
    )
    return agent.before_model_callback(
        callback_context=callback_context, llm_request=llm_request
    )

  def _handle_after_model_callback(
      self,
      invocation_context: InvocationContext,
      llm_response: LlmResponse,
      model_response_event: Event,
  ) -> Optional[LlmResponse]:
    from ...agents.llm_agent import LlmAgent

    agent = invocation_context.agent
    if not isinstance(agent, LlmAgent):
      return

    if not agent.after_model_callback:
      return

    callback_context = CallbackContext(
        invocation_context, event_actions=model_response_event.actions
    )
    return agent.after_model_callback(
        callback_context=callback_context, llm_response=llm_response
    )

  def _finalize_model_response_event(
      self,
      llm_request: LlmRequest,
      llm_response: LlmResponse,
      model_response_event: Event,
  ) -> Event:
    model_response_event = Event.model_validate({
        **model_response_event.model_dump(exclude_none=True),
        **llm_response.model_dump(exclude_none=True),
    })

    if model_response_event.content:
      function_calls = model_response_event.get_function_calls()
      if function_calls:
        functions.populate_client_function_call_id(model_response_event)
        model_response_event.long_running_tool_ids = (
            functions.get_long_running_function_calls(
                function_calls, llm_request.tools_dict
            )
        )

    return model_response_event

  def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm:
    from ...agents.llm_agent import LlmAgent

    return cast(LlmAgent, invocation_context.agent).canonical_model
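

# --- Illustrative sketch (not part of this commit): a concrete flow is just
# --- BaseLlmFlow plus an ordered processor pipeline; SingleFlow (referenced
# --- elsewhere in this change) follows this shape. The processor selection
# --- below is an assumption for illustration, using modules added here.
# class MinimalFlow(BaseLlmFlow):
#
#   def __init__(self):
#     super().__init__()
#     self.request_processors += [
#         basic.request_processor,
#         instructions.request_processor,
#         contents.request_processor,
#     ]
#     self.response_processors += [_nl_planning.response_processor]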
@@ -0,0 +1,72 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles basic information to build the LLM request."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
|
||||
class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
llm_request.model = (
|
||||
agent.canonical_model
|
||||
if isinstance(agent.canonical_model, str)
|
||||
else agent.canonical_model.model
|
||||
)
|
||||
llm_request.config = (
|
||||
agent.generate_content_config.model_copy(deep=True)
|
||||
if agent.generate_content_config
|
||||
else types.GenerateContentConfig()
|
||||
)
|
||||
if agent.output_schema:
|
||||
llm_request.set_output_schema(agent.output_schema)
|
||||
|
||||
llm_request.live_connect_config.response_modalities = (
|
||||
invocation_context.run_config.response_modalities
|
||||
)
|
||||
llm_request.live_connect_config.speech_config = (
|
||||
invocation_context.run_config.speech_config
|
||||
)
|
||||
llm_request.live_connect_config.output_audio_transcription = (
|
||||
invocation_context.run_config.output_audio_transcription
|
||||
)
|
||||
|
||||
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
|
||||
|
||||
return
|
||||
yield # Generator requires yield statement in function body.
|
||||
|
||||
|
||||
request_processor = _BasicLlmRequestProcessor()
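# Aside (illustrative sketch, not part of the change): each request processor is
# an async generator that may yield nothing. The `return` before the unreachable
# `yield` above is the idiom that gives the function that shape; a standalone
# demo of the same idiom:
import asyncio
from typing import AsyncGenerator


async def _no_op_processor() -> AsyncGenerator[str, None]:
  # Mutate request state here, then finish without yielding any event.
  return
  yield  # Unreachable, but marks this function as an async generator.


async def _demo_no_op_processor() -> None:
  events = [event async for event in _no_op_processor()]
  assert events == []  # A valid async generator that produces no items.


if __name__ == '__main__':
  asyncio.run(_demo_no_op_processor())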
|
||||
@@ -0,0 +1,390 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from .functions import remove_client_function_call_id
|
||||
from .functions import REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
|
||||
|
||||
class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Builds the contents for the LLM request."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
if agent.include_contents != 'none':
|
||||
llm_request.contents = _get_contents(
|
||||
invocation_context.branch,
|
||||
invocation_context.session.events,
|
||||
agent.name,
|
||||
)
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
request_processor = _ContentLlmRequestProcessor()
|
||||
|
||||
|
||||
def _rearrange_events_for_async_function_responses_in_history(
|
||||
events: list[Event],
|
||||
) -> list[Event]:
|
||||
"""Rearrange the async function_response events in the history."""
|
||||
|
||||
function_call_id_to_response_events_index: dict[str, int] = {}
|
||||
for i, event in enumerate(events):
|
||||
function_responses = event.get_function_responses()
|
||||
if function_responses:
|
||||
for function_response in function_responses:
|
||||
function_call_id = function_response.id
|
||||
function_call_id_to_response_events_index[function_call_id] = i
|
||||
|
||||
result_events: list[Event] = []
|
||||
for event in events:
|
||||
if event.get_function_responses():
|
||||
# function_response should be handled together with function_call below.
|
||||
continue
|
||||
elif event.get_function_calls():
|
||||
|
||||
function_response_events_indices = set()
|
||||
for function_call in event.get_function_calls():
|
||||
function_call_id = function_call.id
|
||||
if function_call_id in function_call_id_to_response_events_index:
|
||||
function_response_events_indices.add(
|
||||
function_call_id_to_response_events_index[function_call_id]
|
||||
)
|
||||
result_events.append(event)
|
||||
if not function_response_events_indices:
|
||||
continue
|
||||
if len(function_response_events_indices) == 1:
|
||||
result_events.append(
|
||||
events[next(iter(function_response_events_indices))]
|
||||
)
|
||||
else:  # Merge all async function_responses into one response event
|
||||
result_events.append(
|
||||
_merge_function_response_events(
|
||||
[events[i] for i in sorted(function_response_events_indices)]
|
||||
)
|
||||
)
|
||||
continue
|
||||
else:
|
||||
result_events.append(event)
|
||||
|
||||
return result_events
|
||||
|
||||
|
||||
def _rearrange_events_for_latest_function_response(
|
||||
events: list[Event],
|
||||
) -> list[Event]:
|
||||
"""Rearrange the events for the latest function_response.
|
||||
|
||||
If the latest function_response is for an async function_call, all events
|
||||
between the initial function_call and the latest function_response will be
|
||||
removed.
|
||||
|
||||
Args:
|
||||
events: A list of events.
|
||||
|
||||
Returns:
|
||||
A list of events with the latest function_response rearranged.
|
||||
"""
|
||||
if not events:
|
||||
return events
|
||||
|
||||
function_responses = events[-1].get_function_responses()
|
||||
if not function_responses:
|
||||
# No need to process, since the latest event is not a function_response.
|
||||
return events
|
||||
|
||||
function_responses_ids = set()
|
||||
for function_response in function_responses:
|
||||
function_responses_ids.add(function_response.id)
|
||||
|
||||
function_calls = events[-2].get_function_calls()
|
||||
|
||||
if function_calls:
|
||||
for function_call in function_calls:
|
||||
# The latest function_response is already matched
|
||||
if function_call.id in function_responses_ids:
|
||||
return events
|
||||
|
||||
function_call_event_idx = -1
|
||||
# Look backwards through the history for the corresponding function call event.
|
||||
for idx in range(len(events) - 2, -1, -1):
|
||||
event = events[idx]
|
||||
function_calls = event.get_function_calls()
|
||||
if function_calls:
|
||||
for function_call in function_calls:
|
||||
if function_call.id in function_responses_ids:
|
||||
function_call_event_idx = idx
|
||||
break
|
||||
if function_call_event_idx != -1:
|
||||
# In case the last response event only has part of the responses
|
||||
# for the function calls in the function call event
|
||||
for function_call in function_calls:
|
||||
function_responses_ids.add(function_call.id)
|
||||
break
|
||||
|
||||
if function_call_event_idx == -1:
|
||||
raise ValueError(
|
||||
'No function call event found for function responses ids:'
|
||||
f' {function_responses_ids}'
|
||||
)
|
||||
|
||||
# collect all function response between last function response event
|
||||
# and function call event
|
||||
|
||||
function_response_events: list[Event] = []
|
||||
for idx in range(function_call_event_idx + 1, len(events) - 1):
|
||||
event = events[idx]
|
||||
function_responses = event.get_function_responses()
|
||||
if (
|
||||
function_responses
|
||||
and function_responses[0].id in function_responses_ids
|
||||
):
|
||||
function_response_events.append(event)
|
||||
function_response_events.append(events[-1])
|
||||
|
||||
result_events = events[: function_call_event_idx + 1]
|
||||
result_events.append(
|
||||
_merge_function_response_events(function_response_events)
|
||||
)
|
||||
|
||||
return result_events
|
||||
|
||||
|
||||
def _get_contents(
|
||||
current_branch: Optional[str], events: list[Event], agent_name: str = ''
|
||||
) -> list[types.Content]:
|
||||
"""Get the contents for the LLM request.
|
||||
|
||||
Args:
|
||||
current_branch: The current branch of the agent.
|
||||
events: A list of events.
|
||||
agent_name: The name of the agent.
|
||||
|
||||
Returns:
|
||||
A list of contents.
|
||||
"""
|
||||
filtered_events = []
|
||||
# Parse the events, keeping the contents and the function calls and
|
||||
# responses from the current agent.
|
||||
for event in events:
|
||||
if not event.content or not event.content.role:
|
||||
# Skip events without content, or whose content was generated neither by the user nor by the model.
|
||||
# E.g. events purely for mutating session state.
|
||||
continue
|
||||
if not _is_event_belongs_to_branch(current_branch, event):
|
||||
# Skip events that do not belong to the current branch.
|
||||
continue
|
||||
if _is_auth_event(event):
|
||||
# Skip auth events.
|
||||
continue
|
||||
filtered_events.append(
|
||||
_convert_foreign_event(event)
|
||||
if _is_other_agent_reply(agent_name, event)
|
||||
else event
|
||||
)
|
||||
|
||||
result_events = _rearrange_events_for_latest_function_response(
|
||||
filtered_events
|
||||
)
|
||||
result_events = _rearrange_events_for_async_function_responses_in_history(
|
||||
result_events
|
||||
)
|
||||
contents = []
|
||||
for event in result_events:
|
||||
content = copy.deepcopy(event.content)
|
||||
remove_client_function_call_id(content)
|
||||
contents.append(content)
|
||||
return contents
|
||||
|
||||
|
||||
def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
|
||||
"""Whether the event is a reply from another agent."""
|
||||
return bool(
|
||||
current_agent_name
|
||||
and event.author != current_agent_name
|
||||
and event.author != 'user'
|
||||
)
|
||||
|
||||
|
||||
def _convert_foreign_event(event: Event) -> Event:
|
||||
"""Converts an event authored by another agent as a user-content event.
|
||||
|
||||
This provides another agent's output as context to the current agent, so
|
||||
that the current agent can continue to respond, e.g. by summarizing the previous
|
||||
agent's reply.
|
||||
|
||||
Args:
|
||||
event: The event to convert.
|
||||
|
||||
Returns:
|
||||
The converted event.
|
||||
|
||||
"""
|
||||
if not event.content or not event.content.parts:
|
||||
return event
|
||||
|
||||
content = types.Content()
|
||||
content.role = 'user'
|
||||
content.parts = [types.Part(text='For context:')]
|
||||
for part in event.content.parts:
|
||||
if part.text:
|
||||
content.parts.append(
|
||||
types.Part(text=f'[{event.author}] said: {part.text}')
|
||||
)
|
||||
elif part.function_call:
|
||||
content.parts.append(
|
||||
types.Part(
|
||||
text=(
|
||||
f'[{event.author}] called tool `{part.function_call.name}`'
|
||||
f' with parameters: {part.function_call.args}'
|
||||
)
|
||||
)
|
||||
)
|
||||
elif part.function_response:
|
||||
# Render the function response as a plain text part.
|
||||
content.parts.append(
|
||||
types.Part(
|
||||
text=(
|
||||
f'[{event.author}] `{part.function_response.name}` tool'
|
||||
f' returned result: {part.function_response.response}'
|
||||
)
|
||||
)
|
||||
)
|
||||
# Fall back to the original part for all other part types.
|
||||
else:
|
||||
content.parts.append(part)
|
||||
|
||||
return Event(
|
||||
timestamp=event.timestamp,
|
||||
author='user',
|
||||
content=content,
|
||||
branch=event.branch,
|
||||
)
|
||||
|
||||
|
||||
def _merge_function_response_events(
|
||||
function_response_events: list[Event],
|
||||
) -> Event:
|
||||
"""Merges a list of function_response events into one event.
|
||||
|
||||
The key goal is to ensure:
|
||||
1. function_calls and function_responses always match in number.
|
||||
2. The function_call and function_response appear consecutively in the content.
|
||||
|
||||
Args:
|
||||
function_response_events: A list of function_response events.
|
||||
NOTE: function_response_events must fulfill these requirements: 1. The
|
||||
list is in increasing order of timestamp; 2. the first event is the
|
||||
initial function_response event; 3. all later events should contain at
|
||||
least one function_response part related to the function_call
|
||||
event. (Note: 3 may not hold when an async function returns some
|
||||
intermediate response; there could also be some intermediate model
|
||||
response events without any function_response, and such events will be
|
||||
ignored.)
|
||||
Caveat: This implementation doesn't support when a parallel function_call
|
||||
event contains async function_call of the same name.
|
||||
|
||||
Returns:
|
||||
A merged event, in which:
|
||||
1. Each later function_response replaces the matching function_response part in
|
||||
the initial function_response event.
|
||||
2. All non-function_response parts will be appended to the part list of
|
||||
the initial function_response event.
|
||||
"""
|
||||
if not function_response_events:
|
||||
raise ValueError('At least one function_response event is required.')
|
||||
|
||||
merged_event = function_response_events[0].model_copy(deep=True)
|
||||
parts_in_merged_event: list[types.Part] = merged_event.content.parts # type: ignore
|
||||
|
||||
if not parts_in_merged_event:
|
||||
raise ValueError('There should be at least one function_response part.')
|
||||
|
||||
part_indices_in_merged_event: dict[str, int] = {}
|
||||
for idx, part in enumerate(parts_in_merged_event):
|
||||
if part.function_response:
|
||||
function_call_id: str = part.function_response.id # type: ignore
|
||||
part_indices_in_merged_event[function_call_id] = idx
|
||||
|
||||
for event in function_response_events[1:]:
|
||||
if not event.content.parts:
|
||||
raise ValueError('There should be at least one function_response part.')
|
||||
|
||||
for part in event.content.parts:
|
||||
if part.function_response:
|
||||
function_call_id: str = part.function_response.id # type: ignore
|
||||
if function_call_id in part_indices_in_merged_event:
|
||||
parts_in_merged_event[
|
||||
part_indices_in_merged_event[function_call_id]
|
||||
] = part
|
||||
else:
|
||||
parts_in_merged_event.append(part)
|
||||
part_indices_in_merged_event[function_call_id] = (
|
||||
len(parts_in_merged_event) - 1
|
||||
)
|
||||
|
||||
else:
|
||||
parts_in_merged_event.append(part)
|
||||
|
||||
return merged_event
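# Usage sketch (hypothetical tool name and call id; relies on this module's
# Event/types imports and assumes Event can be built from just author+content):
# an initial "pending" response for call 'c1' is replaced by the later final
# response, leaving a single merged part.
def _demo_merge_function_response_events() -> None:
  pending = types.Part.from_function_response(
      name='my_tool', response={'status': 'pending'}
  )
  pending.function_response.id = 'c1'
  final = types.Part.from_function_response(
      name='my_tool', response={'result': 42}
  )
  final.function_response.id = 'c1'

  first = Event(author='agent', content=types.Content(role='user', parts=[pending]))
  later = Event(author='agent', content=types.Content(role='user', parts=[final]))

  merged = _merge_function_response_events([first, later])
  assert len(merged.content.parts) == 1
  assert merged.content.parts[0].function_response.response == {'result': 42}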
|
||||
|
||||
|
||||
def _is_event_belongs_to_branch(
|
||||
invocation_branch: Optional[str], event: Event
|
||||
) -> bool:
|
||||
"""Event belongs to a branch, when event.branch is prefix of the invocation branch."""
|
||||
if not invocation_branch or not event.branch:
|
||||
return True
|
||||
return invocation_branch.startswith(event.branch)
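# Concrete (hypothetical) branch names for the prefix rule above: an event
# recorded on an ancestor branch is visible to a descendant invocation, but
# not the other way around.
assert 'root.agent_a.agent_b'.startswith('root.agent_a')  # ancestor event: visible
assert not 'root.agent_a'.startswith('root.agent_a.agent_b')  # descendant event: hidden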
|
||||
|
||||
|
||||
def _is_auth_event(event: Event) -> bool:
|
||||
if not event.content.parts:
|
||||
return False
|
||||
for part in event.content.parts:
|
||||
if (
|
||||
part.function_call
|
||||
and part.function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
):
|
||||
return True
|
||||
if (
|
||||
part.function_response
|
||||
and part.function_response.name == REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
):
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,486 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles function callings for LLM flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ...agents.active_streaming_tool import ActiveStreamingTool
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...auth.auth_tool import AuthToolArguments
|
||||
from ...events.event import Event
|
||||
from ...events.event_actions import EventActions
|
||||
from ...telemetry import trace_tool_call
|
||||
from ...telemetry import trace_tool_response
|
||||
from ...telemetry import tracer
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.tool_context import ToolContext
|
||||
|
||||
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
|
||||
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_client_function_call_id() -> str:
|
||||
return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'
|
||||
|
||||
|
||||
def populate_client_function_call_id(model_response_event: Event) -> None:
|
||||
if not model_response_event.get_function_calls():
|
||||
return
|
||||
for function_call in model_response_event.get_function_calls():
|
||||
if not function_call.id:
|
||||
function_call.id = generate_client_function_call_id()
|
||||
|
||||
|
||||
def remove_client_function_call_id(content: types.Content) -> None:
|
||||
if content and content.parts:
|
||||
for part in content.parts:
|
||||
if (
|
||||
part.function_call
|
||||
and part.function_call.id
|
||||
and part.function_call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
|
||||
):
|
||||
part.function_call.id = None
|
||||
if (
|
||||
part.function_response
|
||||
and part.function_response.id
|
||||
and part.function_response.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
|
||||
):
|
||||
part.function_response.id = None
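# Round-trip sketch (hypothetical call name; uses this module's `types` import):
# ids minted with the 'adk-' prefix are stripped again before contents are sent
# back to the model.
def _demo_client_id_round_trip() -> None:
  call = types.FunctionCall(name='do_something', args={})
  call.id = generate_client_function_call_id()
  assert call.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)

  content = types.Content(role='model', parts=[types.Part(function_call=call)])
  remove_client_function_call_id(content)
  assert content.parts[0].function_call.id is None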
|
||||
|
||||
|
||||
def get_long_running_function_calls(
|
||||
function_calls: list[types.FunctionCall],
|
||||
tools_dict: dict[str, BaseTool],
|
||||
) -> set[str]:
|
||||
long_running_tool_ids = set()
|
||||
for function_call in function_calls:
|
||||
if (
|
||||
function_call.name in tools_dict
|
||||
and tools_dict[function_call.name].is_long_running
|
||||
):
|
||||
long_running_tool_ids.add(function_call.id)
|
||||
|
||||
return long_running_tool_ids
|
||||
|
||||
|
||||
def generate_auth_event(
|
||||
invocation_context: InvocationContext,
|
||||
function_response_event: Event,
|
||||
) -> Optional[Event]:
|
||||
if not function_response_event.actions.requested_auth_configs:
|
||||
return None
|
||||
parts = []
|
||||
long_running_tool_ids = set()
|
||||
for (
|
||||
function_call_id,
|
||||
auth_config,
|
||||
) in function_response_event.actions.requested_auth_configs.items():
|
||||
|
||||
request_euc_function_call = types.FunctionCall(
|
||||
name=REQUEST_EUC_FUNCTION_CALL_NAME,
|
||||
args=AuthToolArguments(
|
||||
function_call_id=function_call_id,
|
||||
auth_config=auth_config,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
request_euc_function_call.id = generate_client_function_call_id()
|
||||
long_running_tool_ids.add(request_euc_function_call.id)
|
||||
parts.append(types.Part(function_call=request_euc_function_call))
|
||||
|
||||
return Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
branch=invocation_context.branch,
|
||||
content=types.Content(
|
||||
parts=parts, role=function_response_event.content.role
|
||||
),
|
||||
long_running_tool_ids=long_running_tool_ids,
|
||||
)
|
||||
|
||||
|
||||
async def handle_function_calls_async(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
filters: Optional[set[str]] = None,
|
||||
) -> Optional[Event]:
|
||||
"""Calls the functions and returns the function response event."""
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
function_calls = function_call_event.get_function_calls()
|
||||
|
||||
function_response_events: list[Event] = []
|
||||
for function_call in function_calls:
|
||||
if filters and function_call.id not in filters:
|
||||
continue
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context,
|
||||
function_call_event,
|
||||
function_call,
|
||||
tools_dict,
|
||||
)
|
||||
# do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in the Python debugger (pdb).
|
||||
function_args = function_call.args or {}
|
||||
function_response = None
|
||||
# Calls the tool if before_tool_callback does not exist or returns None.
|
||||
if agent.before_tool_callback:
|
||||
function_response = agent.before_tool_callback(
|
||||
tool=tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
|
||||
if not function_response:
|
||||
function_response = await __call_tool_async(
|
||||
tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
|
||||
# Calls after_tool_callback if it exists.
|
||||
if agent.after_tool_callback:
|
||||
new_response = agent.after_tool_callback(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
tool_response=function_response,
|
||||
)
|
||||
if new_response:
|
||||
function_response = new_response
|
||||
|
||||
if tool.is_long_running:
|
||||
# Allow a long-running function to return None to not provide a function response.
|
||||
if not function_response:
|
||||
continue
|
||||
|
||||
# Builds the function response event.
|
||||
function_response_event = __build_response_event(
|
||||
tool, function_response, tool_context, invocation_context
|
||||
)
|
||||
function_response_events.append(function_response_event)
|
||||
|
||||
if not function_response_events:
|
||||
return None
|
||||
merged_event = merge_parallel_function_response_events(
|
||||
function_response_events
|
||||
)
|
||||
if len(function_response_events) > 1:
|
||||
# this is needed for debug traces of parallel calls
|
||||
# individual response with tool.name is traced in __build_response_event
|
||||
# (we drop tool.name from the span name here as this is a merged event)
|
||||
with tracer.start_as_current_span('tool_response'):
|
||||
trace_tool_response(
|
||||
invocation_context=invocation_context,
|
||||
event_id=merged_event.id,
|
||||
function_response_event=merged_event,
|
||||
)
|
||||
return merged_event
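# The callback contract implemented above, reduced to plain callables
# (illustrative sketch only, no ADK types): before_tool_callback may
# short-circuit the tool, after_tool_callback may replace the response, and
# returning None from either leaves the normal path untouched.
def _demo_callback_contract() -> None:
  def tool(args):
    return {'result': args['x'] * 2}

  def before_tool_callback(args):
    return {'result': 'cached'} if args.get('x') == 0 else None

  def after_tool_callback(args, tool_response):
    return None  # None keeps the original response.

  def run(args):
    response = before_tool_callback(args)
    if not response:
      response = tool(args)
    return after_tool_callback(args, response) or response

  assert run({'x': 0}) == {'result': 'cached'}  # Tool skipped.
  assert run({'x': 3}) == {'result': 6}  # Tool called, response kept.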
|
||||
|
||||
|
||||
async def handle_function_calls_live(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
) -> Optional[Event]:
|
||||
"""Calls the functions and returns the function response event."""
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = cast(LlmAgent, invocation_context.agent)
|
||||
function_calls = function_call_event.get_function_calls()
|
||||
|
||||
function_response_events: list[Event] = []
|
||||
for function_call in function_calls:
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context, function_call_event, function_call, tools_dict
|
||||
)
|
||||
# do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in the Python debugger (pdb).
|
||||
function_args = function_call.args or {}
|
||||
function_response = None
|
||||
# Calls the tool if before_tool_callback does not exist or returns None.
|
||||
if agent.before_tool_callback:
|
||||
function_response = agent.before_tool_callback(
|
||||
tool, function_args, tool_context
|
||||
)
|
||||
|
||||
if not function_response:
|
||||
function_response = await _process_function_live_helper(
|
||||
tool, tool_context, function_call, function_args, invocation_context
|
||||
)
|
||||
|
||||
# Calls after_tool_callback if it exists.
|
||||
if agent.after_tool_callback:
|
||||
new_response = agent.after_tool_callback(
|
||||
tool,
|
||||
function_args,
|
||||
tool_context,
|
||||
function_response,
|
||||
)
|
||||
if new_response:
|
||||
function_response = new_response
|
||||
|
||||
if tool.is_long_running:
|
||||
# Allow a long-running function to return None to not provide a function response.
|
||||
if not function_response:
|
||||
continue
|
||||
|
||||
# Builds the function response event.
|
||||
function_response_event = __build_response_event(
|
||||
tool, function_response, tool_context, invocation_context
|
||||
)
|
||||
function_response_events.append(function_response_event)
|
||||
|
||||
if not function_response_events:
|
||||
return None
|
||||
merged_event = merge_parallel_function_response_events(
|
||||
function_response_events
|
||||
)
|
||||
return merged_event
|
||||
|
||||
|
||||
async def _process_function_live_helper(
|
||||
tool, tool_context, function_call, function_args, invocation_context
|
||||
):
|
||||
function_response = None
|
||||
# Check if this is a stop_streaming function call
|
||||
if (
|
||||
function_call.name == 'stop_streaming'
|
||||
and 'function_name' in function_args
|
||||
):
|
||||
function_name = function_args['function_name']
|
||||
active_tasks = invocation_context.active_streaming_tools
|
||||
if (
|
||||
function_name in active_tasks
|
||||
and active_tasks[function_name].task
|
||||
and not active_tasks[function_name].task.done()
|
||||
):
|
||||
task = active_tasks[function_name].task
|
||||
task.cancel()
|
||||
try:
|
||||
# Wait for the task to be cancelled
|
||||
await asyncio.wait_for(task, timeout=1.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
# Log the specific condition
|
||||
if task.cancelled():
|
||||
logger.info(f'Task {function_name} was cancelled successfully')
|
||||
elif task.done():
|
||||
logger.info(f'Task {function_name} completed during cancellation')
|
||||
else:
|
||||
logger.warning(
|
||||
f'Task {function_name} might still be running after'
|
||||
' cancellation timeout'
|
||||
)
|
||||
function_response = {
|
||||
'status': f'The task is not cancelled yet for {function_name}.'
|
||||
}
|
||||
if not function_response:
|
||||
# Clean up the reference
|
||||
active_tasks[function_name].task = None
|
||||
|
||||
function_response = {
|
||||
'status': f'Successfully stopped streaming function {function_name}'
|
||||
}
|
||||
else:
|
||||
function_response = {
|
||||
'status': f'No active streaming function named {function_name} found'
|
||||
}
|
||||
elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
|
||||
# for streaming tool use case
|
||||
# we require the function to be an async generator function
|
||||
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
||||
try:
|
||||
async for result in __call_tool_live(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
):
|
||||
updated_content = types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text=f'Function {tool.name} returned: {result}'
|
||||
)
|
||||
],
|
||||
)
|
||||
invocation_context.live_request_queue.send_content(updated_content)
|
||||
except asyncio.CancelledError:
|
||||
raise # Re-raise to properly propagate the cancellation
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_tool_and_update_queue(tool, function_args, tool_context)
|
||||
)
|
||||
if invocation_context.active_streaming_tools is None:
|
||||
invocation_context.active_streaming_tools = {}
|
||||
if tool.name in invocation_context.active_streaming_tools:
|
||||
invocation_context.active_streaming_tools[tool.name].task = task
|
||||
else:
|
||||
invocation_context.active_streaming_tools[tool.name] = (
|
||||
ActiveStreamingTool(task=task)
|
||||
)
|
||||
# Immediately return a pending response.
|
||||
# This is required by current live model.
|
||||
function_response = {
|
||||
'status': (
|
||||
'The function is running asynchronously and the results are'
|
||||
' pending.'
|
||||
)
|
||||
}
|
||||
else:
|
||||
function_response = await __call_tool_async(
|
||||
tool, args=function_args, tool_context=tool_context
|
||||
)
|
||||
return function_response
|
||||
|
||||
|
||||
def _get_tool_and_context(
|
||||
invocation_context: InvocationContext,
|
||||
function_call_event: Event,
|
||||
function_call: types.FunctionCall,
|
||||
tools_dict: dict[str, BaseTool],
|
||||
):
|
||||
if function_call.name not in tools_dict:
|
||||
raise ValueError(
|
||||
f'Function {function_call.name} is not found in the tools_dict.'
|
||||
)
|
||||
|
||||
tool_context = ToolContext(
|
||||
invocation_context=invocation_context,
|
||||
function_call_id=function_call.id,
|
||||
)
|
||||
|
||||
tool = tools_dict[function_call.name]
|
||||
|
||||
return (tool, tool_context)
|
||||
|
||||
|
||||
async def __call_tool_live(
|
||||
tool: BaseTool,
|
||||
args: dict[str, object],
|
||||
tool_context: ToolContext,
|
||||
invocation_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Calls the tool asynchronously (awaiting the coroutine)."""
|
||||
with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
|
||||
trace_tool_call(args=args)
|
||||
async for item in tool._call_live(
|
||||
args=args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
):
|
||||
yield item
|
||||
|
||||
|
||||
async def __call_tool_async(
|
||||
tool: BaseTool,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> Any:
|
||||
"""Calls the tool."""
|
||||
with tracer.start_as_current_span(f'tool_call [{tool.name}]'):
|
||||
trace_tool_call(args=args)
|
||||
return await tool.run_async(args=args, tool_context=tool_context)
|
||||
|
||||
|
||||
def __build_response_event(
|
||||
tool: BaseTool,
|
||||
function_result: dict[str, object],
|
||||
tool_context: ToolContext,
|
||||
invocation_context: InvocationContext,
|
||||
) -> Event:
|
||||
with tracer.start_as_current_span(f'tool_response [{tool.name}]'):
|
||||
# The spec requires the result to be a dict.
|
||||
if not isinstance(function_result, dict):
|
||||
function_result = {'result': function_result}
|
||||
|
||||
part_function_response = types.Part.from_function_response(
|
||||
name=tool.name, response=function_result
|
||||
)
|
||||
part_function_response.function_response.id = tool_context.function_call_id
|
||||
|
||||
content = types.Content(
|
||||
role='user',
|
||||
parts=[part_function_response],
|
||||
)
|
||||
|
||||
function_response_event = Event(
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=invocation_context.agent.name,
|
||||
content=content,
|
||||
actions=tool_context.actions,
|
||||
branch=invocation_context.branch,
|
||||
)
|
||||
|
||||
trace_tool_response(
|
||||
invocation_context=invocation_context,
|
||||
event_id=function_response_event.id,
|
||||
function_response_event=function_response_event,
|
||||
)
|
||||
return function_response_event
|
||||
|
||||
|
||||
def merge_parallel_function_response_events(
|
||||
function_response_events: list['Event'],
|
||||
) -> 'Event':
|
||||
if not function_response_events:
|
||||
raise ValueError('No function response events provided.')
|
||||
|
||||
if len(function_response_events) == 1:
|
||||
return function_response_events[0]
|
||||
merged_parts = []
|
||||
for event in function_response_events:
|
||||
if event.content:
|
||||
for part in event.content.parts or []:
|
||||
merged_parts.append(part)
|
||||
|
||||
# Use the first event as the "base" for common attributes
|
||||
base_event = function_response_events[0]
|
||||
|
||||
# Merge actions from all events
|
||||
|
||||
merged_actions = EventActions()
|
||||
merged_requested_auth_configs = {}
|
||||
for event in function_response_events:
|
||||
merged_requested_auth_configs.update(event.actions.requested_auth_configs)
|
||||
merged_actions = merged_actions.model_copy(
|
||||
update=event.actions.model_dump()
|
||||
)
|
||||
merged_actions.requested_auth_configs = merged_requested_auth_configs
|
||||
# Create the new merged event
|
||||
merged_event = Event(
|
||||
invocation_id=Event.new_id(),
|
||||
author=base_event.author,
|
||||
branch=base_event.branch,
|
||||
content=types.Content(role='user', parts=merged_parts),
|
||||
actions=merged_actions,
|
||||
)
|
||||
|
||||
# Use the base_event's timestamp.
|
||||
merged_event.timestamp = base_event.timestamp
|
||||
return merged_event
|
||||
@@ -0,0 +1,47 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Gives the agent identity from the framework."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
|
||||
class _IdentityLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Gives the agent identity from the framework."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
agent = invocation_context.agent
|
||||
si = [f'You are an agent. Your internal name is "{agent.name}".']
|
||||
if agent.description:
|
||||
si.append(f' The description about you is "{agent.description}"')
|
||||
llm_request.append_instructions(si)
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
request_processor = _IdentityLlmRequestProcessor()
|
||||
@@ -0,0 +1,137 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles instructions and global instructions for LLM flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...events.event import Event
|
||||
from ...sessions.state import State
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...models.llm_request import LlmRequest
|
||||
|
||||
|
||||
class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Handles instructions and global instructions for LLM flow."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
root_agent: BaseAgent = agent.root_agent
|
||||
|
||||
# Appends global instructions if set.
|
||||
if (
|
||||
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
||||
): # not empty str
|
||||
raw_si = root_agent.canonical_global_instruction(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
si = _populate_values(raw_si, invocation_context)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Appends agent instructions if set.
|
||||
if agent.instruction: # not empty str
|
||||
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
||||
si = _populate_values(raw_si, invocation_context)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Maintain async generator behavior
|
||||
if False: # Ensures it behaves as a generator
|
||||
yield # This is a no-op but maintains generator structure
|
||||
|
||||
|
||||
request_processor = _InstructionsLlmRequestProcessor()
|
||||
|
||||
|
||||
def _populate_values(
|
||||
instruction_template: str,
|
||||
context: InvocationContext,
|
||||
) -> str:
|
||||
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
||||
|
||||
def _replace_match(match) -> str:
|
||||
var_name = match.group().lstrip('{').rstrip('}').strip()
|
||||
optional = False
|
||||
if var_name.endswith('?'):
|
||||
optional = True
|
||||
var_name = var_name.removesuffix('?')
|
||||
if var_name.startswith('artifact.'):
|
||||
var_name = var_name.removeprefix('artifact.')
|
||||
if context.artifact_service is None:
|
||||
raise ValueError('Artifact service is not initialized.')
|
||||
artifact = context.artifact_service.load_artifact(
|
||||
app_name=context.session.app_name,
|
||||
user_id=context.session.user_id,
|
||||
session_id=context.session.id,
|
||||
filename=var_name,
|
||||
)
|
||||
if not artifact:
|
||||
raise KeyError(f'Artifact {var_name} not found.')
|
||||
return str(artifact)
|
||||
else:
|
||||
if not _is_valid_state_name(var_name):
|
||||
return match.group()
|
||||
if var_name in context.session.state:
|
||||
return str(context.session.state[var_name])
|
||||
else:
|
||||
if optional:
|
||||
return ''
|
||||
else:
|
||||
raise KeyError(f'Context variable not found: `{var_name}`.')
|
||||
|
||||
return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
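# The substitution rule above, stripped of the session/artifact plumbing
# (illustrative sketch; state dict and template are hypothetical; reuses this
# module's `re` import): `{var}` is filled from state and `{var?}` falls back
# to an empty string when missing.
def _demo_populate_values() -> None:
  state = {'user_name': 'Ada'}

  def replace(match: re.Match) -> str:
    name = match.group().lstrip('{').rstrip('}').strip()
    optional = name.endswith('?')
    name = name.removesuffix('?')
    if name in state:
      return str(state[name])
    if optional:
      return ''
    raise KeyError(f'Context variable not found: `{name}`.')

  template = 'Hello {user_name}, your plan is {plan?}.'
  assert re.sub(r'{+[^{}]*}+', replace, template) == 'Hello Ada, your plan is .'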
|
||||
|
||||
|
||||
def _is_valid_state_name(var_name):
|
||||
"""Checks if the variable name is a valid state name.
|
||||
|
||||
Valid state is either:
|
||||
- Valid identifier
|
||||
- <Valid prefix>:<Valid identifier>
|
||||
Anything else is considered invalid.
|
||||
|
||||
Args:
|
||||
var_name: The variable name to check.
|
||||
|
||||
Returns:
|
||||
True if the variable name is a valid state name, False otherwise.
|
||||
"""
|
||||
parts = var_name.split(':')
|
||||
if len(parts) == 1:
|
||||
return var_name.isidentifier()
|
||||
|
||||
if len(parts) == 2:
|
||||
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
|
||||
if (parts[0] + ':') in prefixes:
|
||||
return parts[1].isidentifier()
|
||||
return False
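# A few concrete cases for the rule above (hypothetical names; assumes the
# standard 'app:', 'user:' and 'temp:' prefixes defined on State):
def _demo_state_name_rules() -> None:
  assert _is_valid_state_name('user_name')  # plain identifier
  assert _is_valid_state_name('app:theme')  # known prefix + identifier
  assert not _is_valid_state_name('2fast')  # not an identifier
  assert not _is_valid_state_name('a.b')  # dots are not allowed here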
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Implementation of single flow."""
|
||||
|
||||
import logging
|
||||
|
||||
from ...auth import auth_preprocessor
|
||||
from . import _code_execution
|
||||
from . import _nl_planning
|
||||
from . import basic
|
||||
from . import contents
|
||||
from . import identity
|
||||
from . import instructions
|
||||
from .base_llm_flow import BaseLlmFlow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SingleFlow(BaseLlmFlow):
|
||||
"""SingleFlow is the LLM flows that handles tools calls.
|
||||
|
||||
A single flow only considers the agent itself and its tools.
|
||||
No sub-agents are allowed in a single flow.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.request_processors += [
|
||||
basic.request_processor,
|
||||
auth_preprocessor.request_processor,
|
||||
instructions.request_processor,
|
||||
identity.request_processor,
|
||||
contents.request_processor,
|
||||
# Some implementations of NL Planning mark planning contents as thoughts
|
||||
# in the post processor. Since these need to be unmarked, NL Planning
|
||||
# should be after contents.
|
||||
_nl_planning.request_processor,
|
||||
# Code execution should be after the contents as it mutates the contents
|
||||
# to optimize data files.
|
||||
_code_execution.request_processor,
|
||||
]
|
||||
self.response_processors += [
|
||||
_nl_planning.response_processor,
|
||||
_code_execution.response_processor,
|
||||
]
|
||||