mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
Support async instruction and global instruction provider
PiperOrigin-RevId: 757808335
This commit is contained in:
parent
812485fdfa
commit
4c4cfb74ae
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -96,7 +97,9 @@ AfterToolCallback: TypeAlias = Union[
|
|||||||
list[_SingleAfterToolCallback],
|
list[_SingleAfterToolCallback],
|
||||||
]
|
]
|
||||||
|
|
||||||
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
|
InstructionProvider: TypeAlias = Callable[
|
||||||
|
[ReadonlyContext], Union[str, Awaitable[str]]
|
||||||
|
]
|
||||||
|
|
||||||
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
|
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
|
||||||
ExamplesUnion = Union[list[Example], BaseExampleProvider]
|
ExamplesUnion = Union[list[Example], BaseExampleProvider]
|
||||||
@ -302,7 +305,7 @@ class LlmAgent(BaseAgent):
|
|||||||
ancestor_agent = ancestor_agent.parent_agent
|
ancestor_agent = ancestor_agent.parent_agent
|
||||||
raise ValueError(f'No model found for {self.name}.')
|
raise ValueError(f'No model found for {self.name}.')
|
||||||
|
|
||||||
def canonical_instruction(self, ctx: ReadonlyContext) -> str:
|
async def canonical_instruction(self, ctx: ReadonlyContext) -> str:
|
||||||
"""The resolved self.instruction field to construct instruction for this agent.
|
"""The resolved self.instruction field to construct instruction for this agent.
|
||||||
|
|
||||||
This method is only for use by Agent Development Kit.
|
This method is only for use by Agent Development Kit.
|
||||||
@ -310,9 +313,12 @@ class LlmAgent(BaseAgent):
|
|||||||
if isinstance(self.instruction, str):
|
if isinstance(self.instruction, str):
|
||||||
return self.instruction
|
return self.instruction
|
||||||
else:
|
else:
|
||||||
return self.instruction(ctx)
|
instruction = self.instruction(ctx)
|
||||||
|
if inspect.isawaitable(instruction):
|
||||||
|
instruction = await instruction
|
||||||
|
return instruction
|
||||||
|
|
||||||
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
|
async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
|
||||||
"""The resolved self.instruction field to construct global instruction.
|
"""The resolved self.instruction field to construct global instruction.
|
||||||
|
|
||||||
This method is only for use by Agent Development Kit.
|
This method is only for use by Agent Development Kit.
|
||||||
@ -320,7 +326,10 @@ class LlmAgent(BaseAgent):
|
|||||||
if isinstance(self.global_instruction, str):
|
if isinstance(self.global_instruction, str):
|
||||||
return self.global_instruction
|
return self.global_instruction
|
||||||
else:
|
else:
|
||||||
return self.global_instruction(ctx)
|
global_instruction = self.global_instruction(ctx)
|
||||||
|
if inspect.isawaitable(global_instruction):
|
||||||
|
global_instruction = await global_instruction
|
||||||
|
return global_instruction
|
||||||
|
|
||||||
async def canonical_tools(
|
async def canonical_tools(
|
||||||
self, ctx: ReadonlyContext = None
|
self, ctx: ReadonlyContext = None
|
||||||
|
@ -53,7 +53,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
if (
|
if (
|
||||||
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
||||||
): # not empty str
|
): # not empty str
|
||||||
raw_si = root_agent.canonical_global_instruction(
|
raw_si = await root_agent.canonical_global_instruction(
|
||||||
ReadonlyContext(invocation_context)
|
ReadonlyContext(invocation_context)
|
||||||
)
|
)
|
||||||
si = await _populate_values(raw_si, invocation_context)
|
si = await _populate_values(raw_si, invocation_context)
|
||||||
@ -61,7 +61,9 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
|
|
||||||
# Appends agent instructions if set.
|
# Appends agent instructions if set.
|
||||||
if agent.instruction: # not empty str
|
if agent.instruction: # not empty str
|
||||||
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
raw_si = await agent.canonical_instruction(
|
||||||
|
ReadonlyContext(invocation_context)
|
||||||
|
)
|
||||||
si = await _populate_values(raw_si, invocation_context)
|
si = await _populate_values(raw_si, invocation_context)
|
||||||
llm_request.append_instructions([si])
|
llm_request.append_instructions([si])
|
||||||
|
|
||||||
|
@ -92,6 +92,16 @@ def test_canonical_instruction():
|
|||||||
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
|
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_canonical_instruction():
|
||||||
|
async def _instruction_provider(ctx: ReadonlyContext) -> str:
|
||||||
|
return f'instruction: {ctx.state["state_var"]}'
|
||||||
|
|
||||||
|
agent = LlmAgent(name='test_agent', instruction=_instruction_provider)
|
||||||
|
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||||
|
|
||||||
|
assert agent.canonical_instruction(ctx) == 'instruction: state_value'
|
||||||
|
|
||||||
|
|
||||||
def test_canonical_global_instruction_str():
|
def test_canonical_global_instruction_str():
|
||||||
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
|
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
|
||||||
ctx = _create_readonly_context(agent)
|
ctx = _create_readonly_context(agent)
|
||||||
@ -114,6 +124,21 @@ def test_canonical_global_instruction():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_canonical_global_instruction():
|
||||||
|
async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
|
||||||
|
return f'global instruction: {ctx.state["state_var"]}'
|
||||||
|
|
||||||
|
agent = LlmAgent(
|
||||||
|
name='test_agent', global_instruction=_global_instruction_provider
|
||||||
|
)
|
||||||
|
ctx = _create_readonly_context(agent, state={'state_var': 'state_value'})
|
||||||
|
|
||||||
|
assert (
|
||||||
|
agent.canonical_global_instruction(ctx)
|
||||||
|
== 'global instruction: state_value'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
|
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):
|
||||||
with caplog.at_level('WARNING'):
|
with caplog.at_level('WARNING'):
|
||||||
|
|
||||||
|
@ -92,6 +92,44 @@ async def test_function_system_instruction():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_function_system_instruction():
|
||||||
|
async def build_function_instruction(
|
||||||
|
readonly_context: ReadonlyContext,
|
||||||
|
) -> str:
|
||||||
|
return (
|
||||||
|
"This is the function agent instruction for invocation:"
|
||||||
|
f" {readonly_context.invocation_id}."
|
||||||
|
)
|
||||||
|
|
||||||
|
request = LlmRequest(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
config=types.GenerateContentConfig(system_instruction=""),
|
||||||
|
)
|
||||||
|
agent = Agent(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
name="agent",
|
||||||
|
instruction=build_function_instruction,
|
||||||
|
)
|
||||||
|
invocation_context = utils.create_invocation_context(agent=agent)
|
||||||
|
invocation_context.session = Session(
|
||||||
|
app_name="test_app",
|
||||||
|
user_id="test_user",
|
||||||
|
id="test_id",
|
||||||
|
state={"customerId": "1234567890", "customer_int": 30},
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in instructions.request_processor.run_async(
|
||||||
|
invocation_context,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert request.config.system_instruction == (
|
||||||
|
"This is the function agent instruction for invocation: test_id."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_global_system_instruction():
|
async def test_global_system_instruction():
|
||||||
sub_agent = Agent(
|
sub_agent = Agent(
|
||||||
@ -128,6 +166,90 @@ async def test_global_system_instruction():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_function_global_system_instruction():
|
||||||
|
def sub_agent_si(readonly_context: ReadonlyContext) -> str:
|
||||||
|
return "This is the sub agent instruction."
|
||||||
|
|
||||||
|
def root_agent_gi(readonly_context: ReadonlyContext) -> str:
|
||||||
|
return "This is the global instruction."
|
||||||
|
|
||||||
|
sub_agent = Agent(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
name="sub_agent",
|
||||||
|
instruction=sub_agent_si,
|
||||||
|
)
|
||||||
|
root_agent = Agent(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
name="root_agent",
|
||||||
|
global_instruction=root_agent_gi,
|
||||||
|
sub_agents=[sub_agent],
|
||||||
|
)
|
||||||
|
request = LlmRequest(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
config=types.GenerateContentConfig(system_instruction=""),
|
||||||
|
)
|
||||||
|
invocation_context = utils.create_invocation_context(agent=sub_agent)
|
||||||
|
invocation_context.session = Session(
|
||||||
|
app_name="test_app",
|
||||||
|
user_id="test_user",
|
||||||
|
id="test_id",
|
||||||
|
state={"customerId": "1234567890", "customer_int": 30},
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in instructions.request_processor.run_async(
|
||||||
|
invocation_context,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert request.config.system_instruction == (
|
||||||
|
"This is the global instruction.\n\nThis is the sub agent instruction."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_function_global_system_instruction():
|
||||||
|
async def sub_agent_si(readonly_context: ReadonlyContext) -> str:
|
||||||
|
return "This is the sub agent instruction."
|
||||||
|
|
||||||
|
async def root_agent_gi(readonly_context: ReadonlyContext) -> str:
|
||||||
|
return "This is the global instruction."
|
||||||
|
|
||||||
|
sub_agent = Agent(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
name="sub_agent",
|
||||||
|
instruction=sub_agent_si,
|
||||||
|
)
|
||||||
|
root_agent = Agent(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
name="root_agent",
|
||||||
|
global_instruction=root_agent_gi,
|
||||||
|
sub_agents=[sub_agent],
|
||||||
|
)
|
||||||
|
request = LlmRequest(
|
||||||
|
model="gemini-1.5-flash",
|
||||||
|
config=types.GenerateContentConfig(system_instruction=""),
|
||||||
|
)
|
||||||
|
invocation_context = utils.create_invocation_context(agent=sub_agent)
|
||||||
|
invocation_context.session = Session(
|
||||||
|
app_name="test_app",
|
||||||
|
user_id="test_user",
|
||||||
|
id="test_id",
|
||||||
|
state={"customerId": "1234567890", "customer_int": 30},
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in instructions.request_processor.run_async(
|
||||||
|
invocation_context,
|
||||||
|
request,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert request.config.system_instruction == (
|
||||||
|
"This is the global instruction.\n\nThis is the sub agent instruction."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_build_system_instruction_with_namespace():
|
async def test_build_system_instruction_with_namespace():
|
||||||
request = LlmRequest(
|
request = LlmRequest(
|
||||||
|
Loading…
Reference in New Issue
Block a user