mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 15:14:50 -06:00
Provide inject_session_state as public util method
This is useful for injecting artifacts and session state variable into instruction template typically in instruction providers. PiperOrigin-RevId: 761595473
This commit is contained in:
parent
e060344e39
commit
c5a0437745
@ -26,6 +26,7 @@ from typing_extensions import override
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...events.event import Event
|
||||
from ...sessions.state import State
|
||||
from ...utils import instructions_utils
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -60,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
)
|
||||
si = raw_si
|
||||
if not bypass_state_injection:
|
||||
si = await _populate_values(raw_si, invocation_context)
|
||||
si = await instructions_utils.inject_session_state(
|
||||
raw_si, ReadonlyContext(invocation_context)
|
||||
)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Appends agent instructions if set.
|
||||
@ -70,7 +73,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
)
|
||||
si = raw_si
|
||||
if not bypass_state_injection:
|
||||
si = await _populate_values(raw_si, invocation_context)
|
||||
si = await instructions_utils.inject_session_state(
|
||||
raw_si, ReadonlyContext(invocation_context)
|
||||
)
|
||||
llm_request.append_instructions([si])
|
||||
|
||||
# Maintain async generator behavior
|
||||
@ -79,78 +84,3 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
|
||||
|
||||
request_processor = _InstructionsLlmRequestProcessor()
|
||||
|
||||
|
||||
async def _populate_values(
|
||||
instruction_template: str,
|
||||
context: InvocationContext,
|
||||
) -> str:
|
||||
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
||||
|
||||
async def _async_sub(pattern, repl_async_fn, string) -> str:
|
||||
result = []
|
||||
last_end = 0
|
||||
for match in re.finditer(pattern, string):
|
||||
result.append(string[last_end : match.start()])
|
||||
replacement = await repl_async_fn(match)
|
||||
result.append(replacement)
|
||||
last_end = match.end()
|
||||
result.append(string[last_end:])
|
||||
return ''.join(result)
|
||||
|
||||
async def _replace_match(match) -> str:
|
||||
var_name = match.group().lstrip('{').rstrip('}').strip()
|
||||
optional = False
|
||||
if var_name.endswith('?'):
|
||||
optional = True
|
||||
var_name = var_name.removesuffix('?')
|
||||
if var_name.startswith('artifact.'):
|
||||
var_name = var_name.removeprefix('artifact.')
|
||||
if context.artifact_service is None:
|
||||
raise ValueError('Artifact service is not initialized.')
|
||||
artifact = await context.artifact_service.load_artifact(
|
||||
app_name=context.session.app_name,
|
||||
user_id=context.session.user_id,
|
||||
session_id=context.session.id,
|
||||
filename=var_name,
|
||||
)
|
||||
if not var_name:
|
||||
raise KeyError(f'Artifact {var_name} not found.')
|
||||
return str(artifact)
|
||||
else:
|
||||
if not _is_valid_state_name(var_name):
|
||||
return match.group()
|
||||
if var_name in context.session.state:
|
||||
return str(context.session.state[var_name])
|
||||
else:
|
||||
if optional:
|
||||
return ''
|
||||
else:
|
||||
raise KeyError(f'Context variable not found: `{var_name}`.')
|
||||
|
||||
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
||||
|
||||
|
||||
def _is_valid_state_name(var_name):
|
||||
"""Checks if the variable name is a valid state name.
|
||||
|
||||
Valid state is either:
|
||||
- Valid identifier
|
||||
- <Valid prefix>:<Valid identifier>
|
||||
All the others will just return as it is.
|
||||
|
||||
Args:
|
||||
var_name: The variable name to check.
|
||||
|
||||
Returns:
|
||||
True if the variable name is a valid state name, False otherwise.
|
||||
"""
|
||||
parts = var_name.split(':')
|
||||
if len(parts) == 1:
|
||||
return var_name.isidentifier()
|
||||
|
||||
if len(parts) == 2:
|
||||
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
|
||||
if (parts[0] + ':') in prefixes:
|
||||
return parts[1].isidentifier()
|
||||
return False
|
||||
|
13
tests/unittests/public_utils/__init__.py
Normal file
13
tests/unittests/public_utils/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
214
tests/unittests/public_utils/test_instructions_utils.py
Normal file
214
tests/unittests/public_utils/test_instructions_utils.py
Normal file
@ -0,0 +1,214 @@
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from google.adk.sessions import Session
|
||||
from google.adk.utils import instructions_utils
|
||||
import pytest
|
||||
|
||||
from .. import utils
|
||||
|
||||
|
||||
class MockArtifactService:
|
||||
|
||||
def __init__(self, artifacts: dict):
|
||||
self.artifacts = artifacts
|
||||
|
||||
async def load_artifact(self, app_name, user_id, session_id, filename):
|
||||
if filename in self.artifacts:
|
||||
return self.artifacts[filename]
|
||||
else:
|
||||
raise KeyError(f"Artifact '{filename}' not found.")
|
||||
|
||||
|
||||
async def _create_test_readonly_context(
|
||||
state: dict = None,
|
||||
artifact_service: MockArtifactService = None,
|
||||
app_name: str = "test_app",
|
||||
user_id: str = "test_user",
|
||||
session_id: str = "test_session_id",
|
||||
) -> ReadonlyContext:
|
||||
agent = Agent(
|
||||
model="gemini-2.0-flash",
|
||||
name="agent",
|
||||
instruction="test",
|
||||
)
|
||||
invocation_context = await utils.create_invocation_context(agent=agent)
|
||||
invocation_context.session = Session(
|
||||
state=state if state else {},
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
)
|
||||
|
||||
invocation_context.artifact_service = artifact_service
|
||||
return ReadonlyContext(invocation_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state():
|
||||
instruction_template = "Hello {user_name}, you are in {app_state} state."
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
state={"user_name": "Foo", "app_state": "active"}
|
||||
)
|
||||
|
||||
populated_instruction = await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
assert populated_instruction == "Hello Foo, you are in active state."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_artifact():
|
||||
instruction_template = "The artifact content is: {artifact.my_file}"
|
||||
mock_artifact_service = MockArtifactService(
|
||||
{"my_file": "This is my artifact content."}
|
||||
)
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
artifact_service=mock_artifact_service
|
||||
)
|
||||
|
||||
populated_instruction = await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
assert (
|
||||
populated_instruction
|
||||
== "The artifact content is: This is my artifact content."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_optional_state():
|
||||
instruction_template = "Optional value: {optional_value?}"
|
||||
invocation_context = await _create_test_readonly_context()
|
||||
|
||||
populated_instruction = await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
assert populated_instruction == "Optional value: "
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_missing_state_raises_key_error():
|
||||
instruction_template = "Hello {missing_key}!"
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
state={"user_name": "Foo"}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
KeyError, match="Context variable not found: `missing_key`."
|
||||
):
|
||||
await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_missing_artifact_raises_key_error():
|
||||
instruction_template = "The artifact content is: {artifact.missing_file}"
|
||||
mock_artifact_service = MockArtifactService(
|
||||
{"my_file": "This is my artifact content."}
|
||||
)
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
artifact_service=mock_artifact_service
|
||||
)
|
||||
|
||||
with pytest.raises(KeyError, match="Artifact 'missing_file' not found."):
|
||||
await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_invalid_state_name_returns_original():
|
||||
instruction_template = "Hello {invalid-key}!"
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
state={"user_name": "Foo"}
|
||||
)
|
||||
|
||||
populated_instruction = await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
assert populated_instruction == "Hello {invalid-key}!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_invalid_prefix_state_name_returns_original():
|
||||
instruction_template = "Hello {invalid:key}!"
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
state={"user_name": "Foo"}
|
||||
)
|
||||
|
||||
populated_instruction = await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
assert populated_instruction == "Hello {invalid:key}!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_valid_prefix_state():
|
||||
instruction_template = "Hello {app:user_name}!"
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
state={"app:user_name": "Foo"}
|
||||
)
|
||||
|
||||
populated_instruction = await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
assert populated_instruction == "Hello Foo!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_multiple_variables_and_artifacts():
|
||||
instruction_template = """
|
||||
Hello {user_name},
|
||||
You are {user_age} years old.
|
||||
Your favorite color is {favorite_color?}.
|
||||
The artifact says: {artifact.my_file}
|
||||
And another optional artifact: {artifact.other_file}
|
||||
"""
|
||||
mock_artifact_service = MockArtifactService({
|
||||
"my_file": "This is my artifact content.",
|
||||
"other_file": "This is another artifact content.",
|
||||
})
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
state={"user_name": "Foo", "user_age": 30, "favorite_color": "blue"},
|
||||
artifact_service=mock_artifact_service,
|
||||
)
|
||||
|
||||
populated_instruction = await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
expected_instruction = """
|
||||
Hello Foo,
|
||||
You are 30 years old.
|
||||
Your favorite color is blue.
|
||||
The artifact says: This is my artifact content.
|
||||
And another optional artifact: This is another artifact content.
|
||||
"""
|
||||
assert populated_instruction == expected_instruction
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_with_empty_artifact_name_raises_key_error():
|
||||
instruction_template = "The artifact content is: {artifact.}"
|
||||
mock_artifact_service = MockArtifactService(
|
||||
{"my_file": "This is my artifact content."}
|
||||
)
|
||||
invocation_context = await _create_test_readonly_context(
|
||||
artifact_service=mock_artifact_service
|
||||
)
|
||||
|
||||
with pytest.raises(KeyError, match="Artifact '' not found."):
|
||||
await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_session_state_artifact_service_not_initialized_raises_value_error():
|
||||
instruction_template = "The artifact content is: {artifact.my_file}"
|
||||
invocation_context = await _create_test_readonly_context()
|
||||
with pytest.raises(ValueError, match="Artifact service is not initialized."):
|
||||
await instructions_utils.inject_session_state(
|
||||
instruction_template, invocation_context
|
||||
)
|
Loading…
Reference in New Issue
Block a user