ADK changes

PiperOrigin-RevId: 755201925
This commit is contained in:
Shangjie Chen
2025-05-05 21:57:51 -07:00
committed by Copybara-Service
parent 6dec235c13
commit 905c20dad6
12 changed files with 86 additions and 69 deletions

View File

@@ -146,18 +146,20 @@ class AgentTool(BaseTool):
if runner.artifact_service:
# Forward all artifacts to parent session.
for artifact_name in runner.artifact_service.list_artifact_keys(
async for artifact_name in runner.artifact_service.list_artifact_keys(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
):
if artifact := runner.artifact_service.load_artifact(
if artifact := await runner.artifact_service.load_artifact(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
filename=artifact_name,
):
tool_context.save_artifact(filename=artifact_name, artifact=artifact)
await tool_context.save_artifact(
filename=artifact_name, artifact=artifact
)
if (
not last_event

View File

@@ -69,14 +69,14 @@ class LoadArtifactsTool(BaseTool):
tool_context=tool_context,
llm_request=llm_request,
)
self._append_artifacts_to_llm_request(
await self._append_artifacts_to_llm_request(
tool_context=tool_context, llm_request=llm_request
)
def _append_artifacts_to_llm_request(
async def _append_artifacts_to_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
):
artifact_names = tool_context.list_artifacts()
artifact_names = await tool_context.list_artifacts()
if not artifact_names:
return
@@ -96,7 +96,7 @@ class LoadArtifactsTool(BaseTool):
if function_response and function_response.name == 'load_artifacts':
artifact_names = function_response.response['artifact_names']
for artifact_name in artifact_names:
artifact = tool_context.load_artifact(artifact_name)
artifact = await tool_context.load_artifact(artifact_name)
llm_request.contents.append(
types.Content(
role='user',

View File

@@ -69,11 +69,11 @@ class ToolContext(CallbackContext):
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
return AuthHandler(auth_config).get_auth_response(self.state)
def list_artifacts(self) -> list[str]:
async def list_artifacts(self) -> list[str]:
"""Lists the filenames of the artifacts attached to the current session."""
if self._invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
return self._invocation_context.artifact_service.list_artifact_keys(
return await self._invocation_context.artifact_service.list_artifact_keys(
app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id,