From b2a2b11776c56142d69eae27509cbfbacf5d4007 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 21 May 2025 13:38:27 -0700 Subject: [PATCH] ADK changes PiperOrigin-RevId: 761650284 --- src/google/adk/utils/__init__.py | 13 ++ src/google/adk/utils/instructions_utils.py | 132 +++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 src/google/adk/utils/__init__.py create mode 100644 src/google/adk/utils/instructions_utils.py diff --git a/src/google/adk/utils/__init__.py b/src/google/adk/utils/__init__.py new file mode 100644 index 0000000..0a2669d --- /dev/null +++ b/src/google/adk/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/utils/instructions_utils.py b/src/google/adk/utils/instructions_utils.py new file mode 100644 index 0000000..5c63332 --- /dev/null +++ b/src/google/adk/utils/instructions_utils.py @@ -0,0 +1,132 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from ..agents.invocation_context import InvocationContext +from ..agents.readonly_context import ReadonlyContext +from ..sessions.state import State + +__all__ = [ + 'inject_session_state', +] + + +async def inject_session_state( + template: str, + readonly_context: ReadonlyContext, +) -> str: + """Populates values in the instruction template, e.g. state, artifact, etc. + + This method is intended to be used in InstructionProvider based instruction + and global_instruction which are called with readonly_context. + + e.g. + ``` + ... + from google.adk.utils import instructions_utils + + async def build_instruction( + readonly_context: ReadonlyContext, + ) -> str: + return await instructions_utils.inject_session_state( + 'You can inject a state variable like {var_name} or an artifact ' + '{artifact.file_name} into the instruction template.', + readonly_context, + ) + + agent = Agent( + model="gemini-2.0-flash", + name="agent", + instruction=build_instruction, + ) + ``` + + Args: + template: The instruction template. + readonly_context: The read-only context + + Returns: + The instruction template with values populated. + """ + + invocation_context = readonly_context._invocation_context + + async def _async_sub(pattern, repl_async_fn, string) -> str: + result = [] + last_end = 0 + for match in re.finditer(pattern, string): + result.append(string[last_end : match.start()]) + replacement = await repl_async_fn(match) + result.append(replacement) + last_end = match.end() + result.append(string[last_end:]) + return ''.join(result) + + async def _replace_match(match) -> str: + var_name = match.group().lstrip('{').rstrip('}').strip() + optional = False + if var_name.endswith('?'): + optional = True + var_name = var_name.removesuffix('?') + if var_name.startswith('artifact.'): + var_name = var_name.removeprefix('artifact.') + if invocation_context.artifact_service is None: + raise ValueError('Artifact service is not initialized.') + artifact = await invocation_context.artifact_service.load_artifact( + app_name=invocation_context.session.app_name, + user_id=invocation_context.session.user_id, + session_id=invocation_context.session.id, + filename=var_name, + ) + if not var_name: + raise KeyError(f'Artifact {var_name} not found.') + return str(artifact) + else: + if not _is_valid_state_name(var_name): + return match.group() + if var_name in invocation_context.session.state: + return str(invocation_context.session.state[var_name]) + else: + if optional: + return '' + else: + raise KeyError(f'Context variable not found: `{var_name}`.') + + return await _async_sub(r'{+[^{}]*}+', _replace_match, template) + + +def _is_valid_state_name(var_name): + """Checks if the variable name is a valid state name. + + Valid state is either: + - Valid identifier + - : + All the others will just return as it is. + + Args: + var_name: The variable name to check. + + Returns: + True if the variable name is a valid state name, False otherwise. + """ + parts = var_name.split(':') + if len(parts) == 1: + return var_name.isidentifier() + + if len(parts) == 2: + prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX] + if (parts[0] + ':') in prefixes: + return parts[1].isidentifier() + return False