From c5a04377455ba5aa6371aa82f3b91f2b7301af15 Mon Sep 17 00:00:00 2001 From: Selcuk Gun Date: Wed, 21 May 2025 11:14:41 -0700 Subject: [PATCH] Provide inject_session_state as public util method This is useful for injecting artifacts and session state variable into instruction template typically in instruction providers. PiperOrigin-RevId: 761595473 --- .../adk/flows/llm_flows/instructions.py | 84 +------ tests/unittests/public_utils/__init__.py | 13 ++ .../public_utils/test_instructions_utils.py | 214 ++++++++++++++++++ 3 files changed, 234 insertions(+), 77 deletions(-) create mode 100644 tests/unittests/public_utils/__init__.py create mode 100644 tests/unittests/public_utils/test_instructions_utils.py diff --git a/src/google/adk/flows/llm_flows/instructions.py b/src/google/adk/flows/llm_flows/instructions.py index 2956d6f..77a1afe 100644 --- a/src/google/adk/flows/llm_flows/instructions.py +++ b/src/google/adk/flows/llm_flows/instructions.py @@ -26,6 +26,7 @@ from typing_extensions import override from ...agents.readonly_context import ReadonlyContext from ...events.event import Event from ...sessions.state import State +from ...utils import instructions_utils from ._base_llm_processor import BaseLlmRequestProcessor if TYPE_CHECKING: @@ -60,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): ) si = raw_si if not bypass_state_injection: - si = await _populate_values(raw_si, invocation_context) + si = await instructions_utils.inject_session_state( + raw_si, ReadonlyContext(invocation_context) + ) llm_request.append_instructions([si]) # Appends agent instructions if set. @@ -70,7 +73,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): ) si = raw_si if not bypass_state_injection: - si = await _populate_values(raw_si, invocation_context) + si = await instructions_utils.inject_session_state( + raw_si, ReadonlyContext(invocation_context) + ) llm_request.append_instructions([si]) # Maintain async generator behavior @@ -79,78 +84,3 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor): request_processor = _InstructionsLlmRequestProcessor() - - -async def _populate_values( - instruction_template: str, - context: InvocationContext, -) -> str: - """Populates values in the instruction template, e.g. state, artifact, etc.""" - - async def _async_sub(pattern, repl_async_fn, string) -> str: - result = [] - last_end = 0 - for match in re.finditer(pattern, string): - result.append(string[last_end : match.start()]) - replacement = await repl_async_fn(match) - result.append(replacement) - last_end = match.end() - result.append(string[last_end:]) - return ''.join(result) - - async def _replace_match(match) -> str: - var_name = match.group().lstrip('{').rstrip('}').strip() - optional = False - if var_name.endswith('?'): - optional = True - var_name = var_name.removesuffix('?') - if var_name.startswith('artifact.'): - var_name = var_name.removeprefix('artifact.') - if context.artifact_service is None: - raise ValueError('Artifact service is not initialized.') - artifact = await context.artifact_service.load_artifact( - app_name=context.session.app_name, - user_id=context.session.user_id, - session_id=context.session.id, - filename=var_name, - ) - if not var_name: - raise KeyError(f'Artifact {var_name} not found.') - return str(artifact) - else: - if not _is_valid_state_name(var_name): - return match.group() - if var_name in context.session.state: - return str(context.session.state[var_name]) - else: - if optional: - return '' - else: - raise KeyError(f'Context variable not found: `{var_name}`.') - - return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template) - - -def _is_valid_state_name(var_name): - """Checks if the variable name is a valid state name. - - Valid state is either: - - Valid identifier - - : - All the others will just return as it is. - - Args: - var_name: The variable name to check. - - Returns: - True if the variable name is a valid state name, False otherwise. - """ - parts = var_name.split(':') - if len(parts) == 1: - return var_name.isidentifier() - - if len(parts) == 2: - prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX] - if (parts[0] + ':') in prefixes: - return parts[1].isidentifier() - return False diff --git a/tests/unittests/public_utils/__init__.py b/tests/unittests/public_utils/__init__.py new file mode 100644 index 0000000..0a2669d --- /dev/null +++ b/tests/unittests/public_utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/public_utils/test_instructions_utils.py b/tests/unittests/public_utils/test_instructions_utils.py new file mode 100644 index 0000000..8fc7647 --- /dev/null +++ b/tests/unittests/public_utils/test_instructions_utils.py @@ -0,0 +1,214 @@ +from google.adk.agents import Agent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.sessions import Session +from google.adk.utils import instructions_utils +import pytest + +from .. import utils + + +class MockArtifactService: + + def __init__(self, artifacts: dict): + self.artifacts = artifacts + + async def load_artifact(self, app_name, user_id, session_id, filename): + if filename in self.artifacts: + return self.artifacts[filename] + else: + raise KeyError(f"Artifact '{filename}' not found.") + + +async def _create_test_readonly_context( + state: dict = None, + artifact_service: MockArtifactService = None, + app_name: str = "test_app", + user_id: str = "test_user", + session_id: str = "test_session_id", +) -> ReadonlyContext: + agent = Agent( + model="gemini-2.0-flash", + name="agent", + instruction="test", + ) + invocation_context = await utils.create_invocation_context(agent=agent) + invocation_context.session = Session( + state=state if state else {}, + app_name=app_name, + user_id=user_id, + id=session_id, + ) + + invocation_context.artifact_service = artifact_service + return ReadonlyContext(invocation_context) + + +@pytest.mark.asyncio +async def test_inject_session_state(): + instruction_template = "Hello {user_name}, you are in {app_state} state." + invocation_context = await _create_test_readonly_context( + state={"user_name": "Foo", "app_state": "active"} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Hello Foo, you are in active state." + + +@pytest.mark.asyncio +async def test_inject_session_state_with_artifact(): + instruction_template = "The artifact content is: {artifact.my_file}" + mock_artifact_service = MockArtifactService( + {"my_file": "This is my artifact content."} + ) + invocation_context = await _create_test_readonly_context( + artifact_service=mock_artifact_service + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert ( + populated_instruction + == "The artifact content is: This is my artifact content." + ) + + +@pytest.mark.asyncio +async def test_inject_session_state_with_optional_state(): + instruction_template = "Optional value: {optional_value?}" + invocation_context = await _create_test_readonly_context() + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Optional value: " + + +@pytest.mark.asyncio +async def test_inject_session_state_with_missing_state_raises_key_error(): + instruction_template = "Hello {missing_key}!" + invocation_context = await _create_test_readonly_context( + state={"user_name": "Foo"} + ) + + with pytest.raises( + KeyError, match="Context variable not found: `missing_key`." + ): + await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + + +@pytest.mark.asyncio +async def test_inject_session_state_with_missing_artifact_raises_key_error(): + instruction_template = "The artifact content is: {artifact.missing_file}" + mock_artifact_service = MockArtifactService( + {"my_file": "This is my artifact content."} + ) + invocation_context = await _create_test_readonly_context( + artifact_service=mock_artifact_service + ) + + with pytest.raises(KeyError, match="Artifact 'missing_file' not found."): + await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + + +@pytest.mark.asyncio +async def test_inject_session_state_with_invalid_state_name_returns_original(): + instruction_template = "Hello {invalid-key}!" + invocation_context = await _create_test_readonly_context( + state={"user_name": "Foo"} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Hello {invalid-key}!" + + +@pytest.mark.asyncio +async def test_inject_session_state_with_invalid_prefix_state_name_returns_original(): + instruction_template = "Hello {invalid:key}!" + invocation_context = await _create_test_readonly_context( + state={"user_name": "Foo"} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Hello {invalid:key}!" + + +@pytest.mark.asyncio +async def test_inject_session_state_with_valid_prefix_state(): + instruction_template = "Hello {app:user_name}!" + invocation_context = await _create_test_readonly_context( + state={"app:user_name": "Foo"} + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + assert populated_instruction == "Hello Foo!" + + +@pytest.mark.asyncio +async def test_inject_session_state_with_multiple_variables_and_artifacts(): + instruction_template = """ + Hello {user_name}, + You are {user_age} years old. + Your favorite color is {favorite_color?}. + The artifact says: {artifact.my_file} + And another optional artifact: {artifact.other_file} + """ + mock_artifact_service = MockArtifactService({ + "my_file": "This is my artifact content.", + "other_file": "This is another artifact content.", + }) + invocation_context = await _create_test_readonly_context( + state={"user_name": "Foo", "user_age": 30, "favorite_color": "blue"}, + artifact_service=mock_artifact_service, + ) + + populated_instruction = await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + expected_instruction = """ + Hello Foo, + You are 30 years old. + Your favorite color is blue. + The artifact says: This is my artifact content. + And another optional artifact: This is another artifact content. + """ + assert populated_instruction == expected_instruction + + +@pytest.mark.asyncio +async def test_inject_session_state_with_empty_artifact_name_raises_key_error(): + instruction_template = "The artifact content is: {artifact.}" + mock_artifact_service = MockArtifactService( + {"my_file": "This is my artifact content."} + ) + invocation_context = await _create_test_readonly_context( + artifact_service=mock_artifact_service + ) + + with pytest.raises(KeyError, match="Artifact '' not found."): + await instructions_utils.inject_session_state( + instruction_template, invocation_context + ) + + +@pytest.mark.asyncio +async def test_inject_session_state_artifact_service_not_initialized_raises_value_error(): + instruction_template = "The artifact content is: {artifact.my_file}" + invocation_context = await _create_test_readonly_context() + with pytest.raises(ValueError, match="Artifact service is not initialized."): + await instructions_utils.inject_session_state( + instruction_template, invocation_context + )