adk-python/src/google/adk/tools/load_artifacts_tool.py

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import json
from typing import Any
from typing import TYPE_CHECKING

from google.genai import types
from typing_extensions import override

from .base_tool import BaseTool

if TYPE_CHECKING:
  from ..models.llm_request import LlmRequest
  from .tool_context import ToolContext


class LoadArtifactsTool(BaseTool):
  """A tool that loads the artifacts and adds them to the session."""

  def __init__(self):
    super().__init__(
        name='load_artifacts',
        description='Loads the artifacts and adds them to the session.',
    )

  def _get_declaration(self) -> types.FunctionDeclaration | None:
    return types.FunctionDeclaration(
        name=self.name,
        description=self.description,
        parameters=types.Schema(
            type=types.Type.OBJECT,
            properties={
                'artifact_names': types.Schema(
                    type=types.Type.ARRAY,
                    items=types.Schema(
                        type=types.Type.STRING,
                    ),
                )
            },
        ),
    )
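
  # Illustrative sketch of what the declaration above yields at call time: a
  # model that wants artifact contents is expected to emit a function call
  # roughly like load_artifacts(artifact_names=['report.pdf', 'notes.txt']);
  # the artifact names here are hypothetical, and the exact wire format
  # depends on the model and google-genai version in use.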

  @override
  async def run_async(
      self, *, args: dict[str, Any], tool_context: ToolContext
  ) -> Any:
    artifact_names: list[str] = args.get('artifact_names', [])
    return {'artifact_names': artifact_names}
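
  # Note: run_async only echoes the requested names back as the function
  # response; the artifact contents themselves are attached to the next model
  # request by process_llm_request below rather than being stored in the
  # session.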

  @override
  async def process_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ) -> None:
    await super().process_llm_request(
        tool_context=tool_context,
        llm_request=llm_request,
    )
    await self._append_artifacts_to_llm_request(
        tool_context=tool_context, llm_request=llm_request
    )
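
  # The helper below covers two model turns: it first appends an instruction
  # advertising the available artifact names, and then, when the request
  # history ends with a `load_artifacts` function response, attaches the named
  # artifacts' contents to the request.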
  async def _append_artifacts_to_llm_request(
      self, *, tool_context: ToolContext, llm_request: LlmRequest
  ):
    artifact_names = await tool_context.list_artifacts()
    if not artifact_names:
      return

    # Tell the model about the available artifacts.
    llm_request.append_instructions([f"""You have a list of artifacts:
{json.dumps(artifact_names)}
When the user asks questions about any of the artifacts, you should call the
`load_artifacts` function to load the artifact. Do not generate any text other
than the function call.
"""])

    # Attach the content of the artifacts if the model requests them.
    # This only adds the content to the model request, instead of the session.
    if llm_request.contents and llm_request.contents[-1].parts:
      function_response = llm_request.contents[-1].parts[0].function_response
      if function_response and function_response.name == 'load_artifacts':
        artifact_names = function_response.response['artifact_names']
        for artifact_name in artifact_names:
          artifact = await tool_context.load_artifact(artifact_name)
          llm_request.contents.append(
              types.Content(
                  role='user',
                  parts=[
                      types.Part.from_text(
                          text=f'Artifact {artifact_name} is:'
                      ),
                      artifact,
                  ],
              )
          )


load_artifacts_tool = LoadArtifactsTool()
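
# A minimal usage sketch (assumes an ADK `Agent` that accepts a `tools` list
# and a runner configured with an artifact service; the model name, agent
# name, and instruction text are placeholders):
#
#   from google.adk.agents import Agent
#
#   agent = Agent(
#       model='gemini-2.0-flash',
#       name='artifact_reader',
#       instruction='Answer questions about the files the user has provided.',
#       tools=[load_artifacts_tool],
#   )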