ADK changes

PiperOrigin-RevId: 761650284
This commit is contained in:
Xiang (Sean) Zhou 2025-05-21 13:38:27 -07:00 committed by Copybara-Service
parent 1773cdab96
commit b2a2b11776
2 changed files with 145 additions and 0 deletions

View File

@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,132 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from ..agents.invocation_context import InvocationContext
from ..agents.readonly_context import ReadonlyContext
from ..sessions.state import State
__all__ = [
'inject_session_state',
]
async def inject_session_state(
template: str,
readonly_context: ReadonlyContext,
) -> str:
"""Populates values in the instruction template, e.g. state, artifact, etc.
This method is intended to be used in InstructionProvider based instruction
and global_instruction which are called with readonly_context.
e.g.
```
...
from google.adk.utils import instructions_utils
async def build_instruction(
readonly_context: ReadonlyContext,
) -> str:
return await instructions_utils.inject_session_state(
'You can inject a state variable like {var_name} or an artifact '
'{artifact.file_name} into the instruction template.',
readonly_context,
)
agent = Agent(
model="gemini-2.0-flash",
name="agent",
instruction=build_instruction,
)
```
Args:
template: The instruction template.
readonly_context: The read-only context
Returns:
The instruction template with values populated.
"""
invocation_context = readonly_context._invocation_context
async def _async_sub(pattern, repl_async_fn, string) -> str:
result = []
last_end = 0
for match in re.finditer(pattern, string):
result.append(string[last_end : match.start()])
replacement = await repl_async_fn(match)
result.append(replacement)
last_end = match.end()
result.append(string[last_end:])
return ''.join(result)
async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
optional = False
if var_name.endswith('?'):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
if invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
artifact = await invocation_context.artifact_service.load_artifact(
app_name=invocation_context.session.app_name,
user_id=invocation_context.session.user_id,
session_id=invocation_context.session.id,
filename=var_name,
)
if not var_name:
raise KeyError(f'Artifact {var_name} not found.')
return str(artifact)
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in invocation_context.session.state:
return str(invocation_context.session.state[var_name])
else:
if optional:
return ''
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
return await _async_sub(r'{+[^{}]*}+', _replace_match, template)
def _is_valid_state_name(var_name):
"""Checks if the variable name is a valid state name.
Valid state is either:
- Valid identifier
- <Valid prefix>:<Valid identifier>
All the others will just return as it is.
Args:
var_name: The variable name to check.
Returns:
True if the variable name is a valid state name, False otherwise.
"""
parts = var_name.split(':')
if len(parts) == 1:
return var_name.isidentifier()
if len(parts) == 2:
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
if (parts[0] + ':') in prefixes:
return parts[1].isidentifier()
return False