diff --git a/src/google/adk/agents/callback_context.py b/src/google/adk/agents/callback_context.py index 93b785c..9d6e311 100644 --- a/src/google/adk/agents/callback_context.py +++ b/src/google/adk/agents/callback_context.py @@ -65,7 +65,7 @@ class CallbackContext(ReadonlyContext): """The user content that started this invocation. READONLY field.""" return self._invocation_context.user_content - def load_artifact( + async def load_artifact( self, filename: str, version: Optional[int] = None ) -> Optional[types.Part]: """Loads an artifact attached to the current session. @@ -80,7 +80,7 @@ class CallbackContext(ReadonlyContext): """ if self._invocation_context.artifact_service is None: raise ValueError("Artifact service is not initialized.") - return self._invocation_context.artifact_service.load_artifact( + return await self._invocation_context.artifact_service.load_artifact( app_name=self._invocation_context.app_name, user_id=self._invocation_context.user_id, session_id=self._invocation_context.session.id, @@ -88,7 +88,7 @@ class CallbackContext(ReadonlyContext): version=version, ) - def save_artifact(self, filename: str, artifact: types.Part) -> int: + async def save_artifact(self, filename: str, artifact: types.Part) -> int: """Saves an artifact and records it as delta for the current session. Args: @@ -100,7 +100,7 @@ class CallbackContext(ReadonlyContext): """ if self._invocation_context.artifact_service is None: raise ValueError("Artifact service is not initialized.") - version = self._invocation_context.artifact_service.save_artifact( + version = await self._invocation_context.artifact_service.save_artifact( app_name=self._invocation_context.app_name, user_id=self._invocation_context.user_id, session_id=self._invocation_context.session.id, diff --git a/src/google/adk/artifacts/base_artifact_service.py b/src/google/adk/artifacts/base_artifact_service.py index 0af9146..8d14fcb 100644 --- a/src/google/adk/artifacts/base_artifact_service.py +++ b/src/google/adk/artifacts/base_artifact_service.py @@ -25,7 +25,7 @@ class BaseArtifactService(ABC): """Abstract base class for artifact services.""" @abstractmethod - def save_artifact( + async def save_artifact( self, *, app_name: str, @@ -53,7 +53,7 @@ class BaseArtifactService(ABC): """ @abstractmethod - def load_artifact( + async def load_artifact( self, *, app_name: str, @@ -81,7 +81,7 @@ class BaseArtifactService(ABC): pass @abstractmethod - def list_artifact_keys( + async def list_artifact_keys( self, *, app_name: str, user_id: str, session_id: str ) -> list[str]: """Lists all the artifact filenames within a session. @@ -97,7 +97,7 @@ class BaseArtifactService(ABC): pass @abstractmethod - def delete_artifact( + async def delete_artifact( self, *, app_name: str, user_id: str, session_id: str, filename: str ) -> None: """Deletes an artifact. @@ -111,7 +111,7 @@ class BaseArtifactService(ABC): pass @abstractmethod - def list_versions( + async def list_versions( self, *, app_name: str, user_id: str, session_id: str, filename: str ) -> list[int]: """Lists all versions of an artifact. diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index 279d5e0..8adbfe5 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -77,7 +77,7 @@ class GcsArtifactService(BaseArtifactService): return f"{app_name}/{user_id}/{session_id}/{filename}/{version}" @override - def save_artifact( + async def save_artifact( self, *, app_name: str, @@ -86,7 +86,7 @@ class GcsArtifactService(BaseArtifactService): filename: str, artifact: types.Part, ) -> int: - versions = self.list_versions( + versions = await self.list_versions( app_name=app_name, user_id=user_id, session_id=session_id, @@ -107,7 +107,7 @@ class GcsArtifactService(BaseArtifactService): return version @override - def load_artifact( + async def load_artifact( self, *, app_name: str, @@ -117,7 +117,7 @@ class GcsArtifactService(BaseArtifactService): version: Optional[int] = None, ) -> Optional[types.Part]: if version is None: - versions = self.list_versions( + versions = await self.list_versions( app_name=app_name, user_id=user_id, session_id=session_id, @@ -141,7 +141,7 @@ class GcsArtifactService(BaseArtifactService): return artifact @override - def list_artifact_keys( + async def list_artifact_keys( self, *, app_name: str, user_id: str, session_id: str ) -> list[str]: filenames = set() @@ -165,10 +165,10 @@ class GcsArtifactService(BaseArtifactService): return sorted(list(filenames)) @override - def delete_artifact( + async def delete_artifact( self, *, app_name: str, user_id: str, session_id: str, filename: str ) -> None: - versions = self.list_versions( + versions = await self.list_versions( app_name=app_name, user_id=user_id, session_id=session_id, @@ -183,7 +183,7 @@ class GcsArtifactService(BaseArtifactService): return @override - def list_versions( + async def list_versions( self, *, app_name: str, user_id: str, session_id: str, filename: str ) -> list[int]: prefix = self._get_blob_name(app_name, user_id, session_id, filename, "") diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py index 8c886f6..fcfb881 100644 --- a/src/google/adk/artifacts/in_memory_artifact_service.py +++ b/src/google/adk/artifacts/in_memory_artifact_service.py @@ -63,7 +63,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel): return f"{app_name}/{user_id}/{session_id}/{filename}" @override - def save_artifact( + async def save_artifact( self, *, app_name: str, @@ -80,7 +80,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel): return version @override - def load_artifact( + async def load_artifact( self, *, app_name: str, @@ -98,7 +98,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel): return versions[version] @override - def list_artifact_keys( + async def list_artifact_keys( self, *, app_name: str, user_id: str, session_id: str ) -> list[str]: session_prefix = f"{app_name}/{user_id}/{session_id}/" @@ -114,7 +114,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel): return sorted(filenames) @override - def delete_artifact( + async def delete_artifact( self, *, app_name: str, user_id: str, session_id: str, filename: str ) -> None: path = self._artifact_path(app_name, user_id, session_id, filename) @@ -123,7 +123,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel): self.artifacts.pop(path, None) @override - def list_versions( + async def list_versions( self, *, app_name: str, user_id: str, session_id: str, filename: str ) -> list[int]: path = self._artifact_path(app_name, user_id, session_id, filename) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index a7747a1..ed663a5 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -503,7 +503,7 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", response_model_exclude_none=True, ) - def load_artifact( + async def load_artifact( app_name: str, user_id: str, session_id: str, @@ -511,7 +511,7 @@ def get_fast_api_app( version: Optional[int] = Query(None), ) -> Optional[types.Part]: app_name = agent_engine_id if agent_engine_id else app_name - artifact = artifact_service.load_artifact( + artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, session_id=session_id, @@ -526,7 +526,7 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", response_model_exclude_none=True, ) - def load_artifact_version( + async def load_artifact_version( app_name: str, user_id: str, session_id: str, @@ -534,7 +534,7 @@ def get_fast_api_app( version_id: int, ) -> Optional[types.Part]: app_name = agent_engine_id if agent_engine_id else app_name - artifact = artifact_service.load_artifact( + artifact = await artifact_service.load_artifact( app_name=app_name, user_id=user_id, session_id=session_id, @@ -549,11 +549,11 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", response_model_exclude_none=True, ) - def list_artifact_names( + async def list_artifact_names( app_name: str, user_id: str, session_id: str ) -> list[str]: app_name = agent_engine_id if agent_engine_id else app_name - return artifact_service.list_artifact_keys( + return await artifact_service.list_artifact_keys( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -561,11 +561,11 @@ def get_fast_api_app( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", response_model_exclude_none=True, ) - def list_artifact_versions( + async def list_artifact_versions( app_name: str, user_id: str, session_id: str, artifact_name: str ) -> list[int]: app_name = agent_engine_id if agent_engine_id else app_name - return artifact_service.list_versions( + return await artifact_service.list_versions( app_name=app_name, user_id=user_id, session_id=session_id, @@ -575,11 +575,11 @@ def get_fast_api_app( @app.delete( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", ) - def delete_artifact( + async def delete_artifact( app_name: str, user_id: str, session_id: str, artifact_name: str ): app_name = agent_engine_id if agent_engine_id else app_name - artifact_service.delete_artifact( + await artifact_service.delete_artifact( app_name=app_name, user_id=user_id, session_id=session_id, diff --git a/src/google/adk/flows/llm_flows/_code_execution.py b/src/google/adk/flows/llm_flows/_code_execution.py index 7ca57af..8a6a56e 100644 --- a/src/google/adk/flows/llm_flows/_code_execution.py +++ b/src/google/adk/flows/llm_flows/_code_execution.py @@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor): if not invocation_context.agent.code_executor: return - for event in _run_pre_processor(invocation_context, llm_request): + async for event in _run_pre_processor(invocation_context, llm_request): yield event # Convert the code execution parts to text parts. @@ -159,10 +159,10 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor): response_processor = _CodeExecutionResponseProcessor() -def _run_pre_processor( +async def _run_pre_processor( invocation_context: InvocationContext, llm_request: LlmRequest, -) -> Generator[Event, None, None]: +) -> AsyncGenerator[Event, None]: """Pre-process the user message by adding the user message to the Colab notebook.""" from ...agents.llm_agent import LlmAgent @@ -242,7 +242,7 @@ def _run_pre_processor( code_executor_context.add_processed_file_names([file.name]) # Emit the execution result, and add it to the LLM request. - execution_result_event = _post_process_code_execution_result( + execution_result_event = await _post_process_code_execution_result( invocation_context, code_executor_context, code_execution_result ) yield execution_result_event @@ -375,7 +375,7 @@ def _get_or_set_execution_id( return execution_id -def _post_process_code_execution_result( +async def _post_process_code_execution_result( invocation_context: InvocationContext, code_executor_context: CodeExecutorContext, code_execution_result: CodeExecutionResult, @@ -406,7 +406,7 @@ def _post_process_code_execution_result( # Handle output files. for output_file in code_execution_result.output_files: - version = invocation_context.artifact_service.save_artifact( + version = await invocation_context.artifact_service.save_artifact( app_name=invocation_context.app_name, user_id=invocation_context.user_id, session_id=invocation_context.session.id, diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py index 860240d..041c867 100644 --- a/src/google/adk/flows/llm_flows/instructions.py +++ b/src/google/adk/flows/llm_flows/instructions.py @@ -56,13 +56,13 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): raw_si = root_agent.canonical_global_instruction( ReadonlyContext(invocation_context) ) - si = _populate_values(raw_si, invocation_context) + si = await _populate_values(raw_si, invocation_context) llm_request.append_instructions([si]) # Appends agent instructions if set. if agent.instruction: # not empty str raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context)) - si = _populate_values(raw_si, invocation_context) + si = await _populate_values(raw_si, invocation_context) llm_request.append_instructions([si]) # Maintain async generator behavior @@ -73,13 +73,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): request_processor = _InstructionsLlmRequestProcessor() -def _populate_values( +async def _populate_values( instruction_template: str, context: InvocationContext, ) -> str: """Populates values in the instruction template, e.g. state, artifact, etc.""" - def _replace_match(match) -> str: + async def _async_sub(pattern, repl_async_fn, string) -> str: + result = [] + last_end = 0 + for match in re.finditer(pattern, string): + result.append(string[last_end : match.start()]) + replacement = await repl_async_fn(match) + result.append(replacement) + last_end = match.end() + result.append(string[last_end:]) + return ''.join(result) + + async def _replace_match(match) -> str: var_name = match.group().lstrip('{').rstrip('}').strip() optional = False if var_name.endswith('?'): @@ -89,7 +100,7 @@ def _populate_values( var_name = var_name.removeprefix('artifact.') if context.artifact_service is None: raise ValueError('Artifact service is not initialized.') - artifact = context.artifact_service.load_artifact( + artifact = await context.artifact_service.load_artifact( app_name=context.session.app_name, user_id=context.session.user_id, session_id=context.session.id, @@ -109,7 +120,7 @@ def _populate_values( else: raise KeyError(f'Context variable not found: `{var_name}`.') - return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template) + return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template) def _is_valid_state_name(var_name): diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 15b1ee2..1ec8631 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -186,7 +186,7 @@ class Runner: root_agent = self.agent if new_message: - self._append_new_message_to_session( + await self._append_new_message_to_session( session, new_message, invocation_context, @@ -199,7 +199,7 @@ class Runner: self.session_service.append_event(session=session, event=event) yield event - def _append_new_message_to_session( + async def _append_new_message_to_session( self, session: Session, new_message: types.Content, @@ -225,7 +225,7 @@ class Runner: if part.inline_data is None: continue file_name = f'artifact_{invocation_context.invocation_id}_{i}' - self.artifact_service.save_artifact( + await self.artifact_service.save_artifact( app_name=self.app_name, user_id=session.user_id, session_id=session.id, diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 3efafdb..aab9ae2 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -146,18 +146,20 @@ class AgentTool(BaseTool): if runner.artifact_service: # Forward all artifacts to parent session. - for artifact_name in runner.artifact_service.list_artifact_keys( + async for artifact_name in runner.artifact_service.list_artifact_keys( app_name=session.app_name, user_id=session.user_id, session_id=session.id, ): - if artifact := runner.artifact_service.load_artifact( + if artifact := await runner.artifact_service.load_artifact( app_name=session.app_name, user_id=session.user_id, session_id=session.id, filename=artifact_name, ): - tool_context.save_artifact(filename=artifact_name, artifact=artifact) + await tool_context.save_artifact( + filename=artifact_name, artifact=artifact + ) if ( not last_event diff --git a/src/google/adk/tools/load_artifacts_tool.py b/src/google/adk/tools/load_artifacts_tool.py index bee650f..db28aef 100644 --- a/src/google/adk/tools/load_artifacts_tool.py +++ b/src/google/adk/tools/load_artifacts_tool.py @@ -69,14 +69,14 @@ class LoadArtifactsTool(BaseTool): tool_context=tool_context, llm_request=llm_request, ) - self._append_artifacts_to_llm_request( + await self._append_artifacts_to_llm_request( tool_context=tool_context, llm_request=llm_request ) - def _append_artifacts_to_llm_request( + async def _append_artifacts_to_llm_request( self, *, tool_context: ToolContext, llm_request: LlmRequest ): - artifact_names = tool_context.list_artifacts() + artifact_names = await tool_context.list_artifacts() if not artifact_names: return @@ -96,7 +96,7 @@ class LoadArtifactsTool(BaseTool): if function_response and function_response.name == 'load_artifacts': artifact_names = function_response.response['artifact_names'] for artifact_name in artifact_names: - artifact = tool_context.load_artifact(artifact_name) + artifact = await tool_context.load_artifact(artifact_name) llm_request.contents.append( types.Content( role='user', diff --git a/src/google/adk/tools/tool_context.py b/src/google/adk/tools/tool_context.py index ad8fce8..e99d42c 100644 --- a/src/google/adk/tools/tool_context.py +++ b/src/google/adk/tools/tool_context.py @@ -69,11 +69,11 @@ class ToolContext(CallbackContext): def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential: return AuthHandler(auth_config).get_auth_response(self.state) - def list_artifacts(self) -> list[str]: + async def list_artifacts(self) -> list[str]: """Lists the filenames of the artifacts attached to the current session.""" if self._invocation_context.artifact_service is None: raise ValueError('Artifact service is not initialized.') - return self._invocation_context.artifact_service.list_artifact_keys( + return await self._invocation_context.artifact_service.list_artifact_keys( app_name=self._invocation_context.app_name, user_id=self._invocation_context.user_id, session_id=self._invocation_context.session.id, diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index e8ce497..6f8ef0b 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -152,13 +152,14 @@ def get_artifact_service( return InMemoryArtifactService() +@pytest.mark.asyncio @pytest.mark.parametrize( "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] ) -def test_load_empty(service_type): +async def test_load_empty(service_type): """Tests loading an artifact when none exists.""" artifact_service = get_artifact_service(service_type) - assert not artifact_service.load_artifact( + assert not await artifact_service.load_artifact( app_name="test_app", user_id="test_user", session_id="session_id", @@ -166,10 +167,11 @@ def test_load_empty(service_type): ) +@pytest.mark.asyncio @pytest.mark.parametrize( "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] ) -def test_save_load_delete(service_type): +async def test_save_load_delete(service_type): """Tests saving, loading, and deleting an artifact.""" artifact_service = get_artifact_service(service_type) artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") @@ -178,7 +180,7 @@ def test_save_load_delete(service_type): session_id = "123" filename = "file456" - artifact_service.save_artifact( + await artifact_service.save_artifact( app_name=app_name, user_id=user_id, session_id=session_id, @@ -186,7 +188,7 @@ def test_save_load_delete(service_type): artifact=artifact, ) assert ( - artifact_service.load_artifact( + await artifact_service.load_artifact( app_name=app_name, user_id=user_id, session_id=session_id, @@ -195,13 +197,13 @@ def test_save_load_delete(service_type): == artifact ) - artifact_service.delete_artifact( + await artifact_service.delete_artifact( app_name=app_name, user_id=user_id, session_id=session_id, filename=filename, ) - assert not artifact_service.load_artifact( + assert not await artifact_service.load_artifact( app_name=app_name, user_id=user_id, session_id=session_id, @@ -209,10 +211,11 @@ def test_save_load_delete(service_type): ) +@pytest.mark.asyncio @pytest.mark.parametrize( "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] ) -def test_list_keys(service_type): +async def test_list_keys(service_type): """Tests listing keys in the artifact service.""" artifact_service = get_artifact_service(service_type) artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") @@ -223,7 +226,7 @@ def test_list_keys(service_type): filenames = [filename + str(i) for i in range(5)] for f in filenames: - artifact_service.save_artifact( + await artifact_service.save_artifact( app_name=app_name, user_id=user_id, session_id=session_id, @@ -232,17 +235,18 @@ def test_list_keys(service_type): ) assert ( - artifact_service.list_artifact_keys( + await artifact_service.list_artifact_keys( app_name=app_name, user_id=user_id, session_id=session_id ) == filenames ) +@pytest.mark.asyncio @pytest.mark.parametrize( "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] ) -def test_list_versions(service_type): +async def test_list_versions(service_type): """Tests listing versions of an artifact.""" artifact_service = get_artifact_service(service_type) @@ -258,7 +262,7 @@ def test_list_versions(service_type): ] for i in range(3): - artifact_service.save_artifact( + await artifact_service.save_artifact( app_name=app_name, user_id=user_id, session_id=session_id, @@ -266,7 +270,7 @@ def test_list_versions(service_type): artifact=versions[i], ) - response_versions = artifact_service.list_versions( + response_versions = await artifact_service.list_versions( app_name=app_name, user_id=user_id, session_id=session_id,