mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00
ADK changes
PiperOrigin-RevId: 755201925
This commit is contained in:
parent
6dec235c13
commit
905c20dad6
@ -65,7 +65,7 @@ class CallbackContext(ReadonlyContext):
|
|||||||
"""The user content that started this invocation. READONLY field."""
|
"""The user content that started this invocation. READONLY field."""
|
||||||
return self._invocation_context.user_content
|
return self._invocation_context.user_content
|
||||||
|
|
||||||
def load_artifact(
|
async def load_artifact(
|
||||||
self, filename: str, version: Optional[int] = None
|
self, filename: str, version: Optional[int] = None
|
||||||
) -> Optional[types.Part]:
|
) -> Optional[types.Part]:
|
||||||
"""Loads an artifact attached to the current session.
|
"""Loads an artifact attached to the current session.
|
||||||
@ -80,7 +80,7 @@ class CallbackContext(ReadonlyContext):
|
|||||||
"""
|
"""
|
||||||
if self._invocation_context.artifact_service is None:
|
if self._invocation_context.artifact_service is None:
|
||||||
raise ValueError("Artifact service is not initialized.")
|
raise ValueError("Artifact service is not initialized.")
|
||||||
return self._invocation_context.artifact_service.load_artifact(
|
return await self._invocation_context.artifact_service.load_artifact(
|
||||||
app_name=self._invocation_context.app_name,
|
app_name=self._invocation_context.app_name,
|
||||||
user_id=self._invocation_context.user_id,
|
user_id=self._invocation_context.user_id,
|
||||||
session_id=self._invocation_context.session.id,
|
session_id=self._invocation_context.session.id,
|
||||||
@ -88,7 +88,7 @@ class CallbackContext(ReadonlyContext):
|
|||||||
version=version,
|
version=version,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_artifact(self, filename: str, artifact: types.Part) -> int:
|
async def save_artifact(self, filename: str, artifact: types.Part) -> int:
|
||||||
"""Saves an artifact and records it as delta for the current session.
|
"""Saves an artifact and records it as delta for the current session.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -100,7 +100,7 @@ class CallbackContext(ReadonlyContext):
|
|||||||
"""
|
"""
|
||||||
if self._invocation_context.artifact_service is None:
|
if self._invocation_context.artifact_service is None:
|
||||||
raise ValueError("Artifact service is not initialized.")
|
raise ValueError("Artifact service is not initialized.")
|
||||||
version = self._invocation_context.artifact_service.save_artifact(
|
version = await self._invocation_context.artifact_service.save_artifact(
|
||||||
app_name=self._invocation_context.app_name,
|
app_name=self._invocation_context.app_name,
|
||||||
user_id=self._invocation_context.user_id,
|
user_id=self._invocation_context.user_id,
|
||||||
session_id=self._invocation_context.session.id,
|
session_id=self._invocation_context.session.id,
|
||||||
|
@ -25,7 +25,7 @@ class BaseArtifactService(ABC):
|
|||||||
"""Abstract base class for artifact services."""
|
"""Abstract base class for artifact services."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_artifact(
|
async def save_artifact(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_name: str,
|
app_name: str,
|
||||||
@ -53,7 +53,7 @@ class BaseArtifactService(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_artifact(
|
async def load_artifact(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_name: str,
|
app_name: str,
|
||||||
@ -81,7 +81,7 @@ class BaseArtifactService(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_artifact_keys(
|
async def list_artifact_keys(
|
||||||
self, *, app_name: str, user_id: str, session_id: str
|
self, *, app_name: str, user_id: str, session_id: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Lists all the artifact filenames within a session.
|
"""Lists all the artifact filenames within a session.
|
||||||
@ -97,7 +97,7 @@ class BaseArtifactService(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_artifact(
|
async def delete_artifact(
|
||||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Deletes an artifact.
|
"""Deletes an artifact.
|
||||||
@ -111,7 +111,7 @@ class BaseArtifactService(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_versions(
|
async def list_versions(
|
||||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
"""Lists all versions of an artifact.
|
"""Lists all versions of an artifact.
|
||||||
|
@ -77,7 +77,7 @@ class GcsArtifactService(BaseArtifactService):
|
|||||||
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
|
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def save_artifact(
|
async def save_artifact(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_name: str,
|
app_name: str,
|
||||||
@ -86,7 +86,7 @@ class GcsArtifactService(BaseArtifactService):
|
|||||||
filename: str,
|
filename: str,
|
||||||
artifact: types.Part,
|
artifact: types.Part,
|
||||||
) -> int:
|
) -> int:
|
||||||
versions = self.list_versions(
|
versions = await self.list_versions(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -107,7 +107,7 @@ class GcsArtifactService(BaseArtifactService):
|
|||||||
return version
|
return version
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def load_artifact(
|
async def load_artifact(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_name: str,
|
app_name: str,
|
||||||
@ -117,7 +117,7 @@ class GcsArtifactService(BaseArtifactService):
|
|||||||
version: Optional[int] = None,
|
version: Optional[int] = None,
|
||||||
) -> Optional[types.Part]:
|
) -> Optional[types.Part]:
|
||||||
if version is None:
|
if version is None:
|
||||||
versions = self.list_versions(
|
versions = await self.list_versions(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -141,7 +141,7 @@ class GcsArtifactService(BaseArtifactService):
|
|||||||
return artifact
|
return artifact
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def list_artifact_keys(
|
async def list_artifact_keys(
|
||||||
self, *, app_name: str, user_id: str, session_id: str
|
self, *, app_name: str, user_id: str, session_id: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
filenames = set()
|
filenames = set()
|
||||||
@ -165,10 +165,10 @@ class GcsArtifactService(BaseArtifactService):
|
|||||||
return sorted(list(filenames))
|
return sorted(list(filenames))
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def delete_artifact(
|
async def delete_artifact(
|
||||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||||
) -> None:
|
) -> None:
|
||||||
versions = self.list_versions(
|
versions = await self.list_versions(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -183,7 +183,7 @@ class GcsArtifactService(BaseArtifactService):
|
|||||||
return
|
return
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def list_versions(
|
async def list_versions(
|
||||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
|
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
|
||||||
|
@ -63,7 +63,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
|||||||
return f"{app_name}/{user_id}/{session_id}/{filename}"
|
return f"{app_name}/{user_id}/{session_id}/{filename}"
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def save_artifact(
|
async def save_artifact(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_name: str,
|
app_name: str,
|
||||||
@ -80,7 +80,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
|||||||
return version
|
return version
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def load_artifact(
|
async def load_artifact(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
app_name: str,
|
app_name: str,
|
||||||
@ -98,7 +98,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
|||||||
return versions[version]
|
return versions[version]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def list_artifact_keys(
|
async def list_artifact_keys(
|
||||||
self, *, app_name: str, user_id: str, session_id: str
|
self, *, app_name: str, user_id: str, session_id: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
||||||
@ -114,7 +114,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
|||||||
return sorted(filenames)
|
return sorted(filenames)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def delete_artifact(
|
async def delete_artifact(
|
||||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||||
) -> None:
|
) -> None:
|
||||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||||
@ -123,7 +123,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
|||||||
self.artifacts.pop(path, None)
|
self.artifacts.pop(path, None)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def list_versions(
|
async def list_versions(
|
||||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||||
|
@ -503,7 +503,7 @@ def get_fast_api_app(
|
|||||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
||||||
response_model_exclude_none=True,
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
def load_artifact(
|
async def load_artifact(
|
||||||
app_name: str,
|
app_name: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
@ -511,7 +511,7 @@ def get_fast_api_app(
|
|||||||
version: Optional[int] = Query(None),
|
version: Optional[int] = Query(None),
|
||||||
) -> Optional[types.Part]:
|
) -> Optional[types.Part]:
|
||||||
app_name = agent_engine_id if agent_engine_id else app_name
|
app_name = agent_engine_id if agent_engine_id else app_name
|
||||||
artifact = artifact_service.load_artifact(
|
artifact = await artifact_service.load_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -526,7 +526,7 @@ def get_fast_api_app(
|
|||||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
|
||||||
response_model_exclude_none=True,
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
def load_artifact_version(
|
async def load_artifact_version(
|
||||||
app_name: str,
|
app_name: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
@ -534,7 +534,7 @@ def get_fast_api_app(
|
|||||||
version_id: int,
|
version_id: int,
|
||||||
) -> Optional[types.Part]:
|
) -> Optional[types.Part]:
|
||||||
app_name = agent_engine_id if agent_engine_id else app_name
|
app_name = agent_engine_id if agent_engine_id else app_name
|
||||||
artifact = artifact_service.load_artifact(
|
artifact = await artifact_service.load_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -549,11 +549,11 @@ def get_fast_api_app(
|
|||||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
|
||||||
response_model_exclude_none=True,
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
def list_artifact_names(
|
async def list_artifact_names(
|
||||||
app_name: str, user_id: str, session_id: str
|
app_name: str, user_id: str, session_id: str
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
app_name = agent_engine_id if agent_engine_id else app_name
|
app_name = agent_engine_id if agent_engine_id else app_name
|
||||||
return artifact_service.list_artifact_keys(
|
return await artifact_service.list_artifact_keys(
|
||||||
app_name=app_name, user_id=user_id, session_id=session_id
|
app_name=app_name, user_id=user_id, session_id=session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -561,11 +561,11 @@ def get_fast_api_app(
|
|||||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
|
||||||
response_model_exclude_none=True,
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
def list_artifact_versions(
|
async def list_artifact_versions(
|
||||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
app_name = agent_engine_id if agent_engine_id else app_name
|
app_name = agent_engine_id if agent_engine_id else app_name
|
||||||
return artifact_service.list_versions(
|
return await artifact_service.list_versions(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -575,11 +575,11 @@ def get_fast_api_app(
|
|||||||
@app.delete(
|
@app.delete(
|
||||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
|
||||||
)
|
)
|
||||||
def delete_artifact(
|
async def delete_artifact(
|
||||||
app_name: str, user_id: str, session_id: str, artifact_name: str
|
app_name: str, user_id: str, session_id: str, artifact_name: str
|
||||||
):
|
):
|
||||||
app_name = agent_engine_id if agent_engine_id else app_name
|
app_name = agent_engine_id if agent_engine_id else app_name
|
||||||
artifact_service.delete_artifact(
|
await artifact_service.delete_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
if not invocation_context.agent.code_executor:
|
if not invocation_context.agent.code_executor:
|
||||||
return
|
return
|
||||||
|
|
||||||
for event in _run_pre_processor(invocation_context, llm_request):
|
async for event in _run_pre_processor(invocation_context, llm_request):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
# Convert the code execution parts to text parts.
|
# Convert the code execution parts to text parts.
|
||||||
@ -159,10 +159,10 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
|
|||||||
response_processor = _CodeExecutionResponseProcessor()
|
response_processor = _CodeExecutionResponseProcessor()
|
||||||
|
|
||||||
|
|
||||||
def _run_pre_processor(
|
async def _run_pre_processor(
|
||||||
invocation_context: InvocationContext,
|
invocation_context: InvocationContext,
|
||||||
llm_request: LlmRequest,
|
llm_request: LlmRequest,
|
||||||
) -> Generator[Event, None, None]:
|
) -> AsyncGenerator[Event, None]:
|
||||||
"""Pre-process the user message by adding the user message to the Colab notebook."""
|
"""Pre-process the user message by adding the user message to the Colab notebook."""
|
||||||
from ...agents.llm_agent import LlmAgent
|
from ...agents.llm_agent import LlmAgent
|
||||||
|
|
||||||
@ -242,7 +242,7 @@ def _run_pre_processor(
|
|||||||
code_executor_context.add_processed_file_names([file.name])
|
code_executor_context.add_processed_file_names([file.name])
|
||||||
|
|
||||||
# Emit the execution result, and add it to the LLM request.
|
# Emit the execution result, and add it to the LLM request.
|
||||||
execution_result_event = _post_process_code_execution_result(
|
execution_result_event = await _post_process_code_execution_result(
|
||||||
invocation_context, code_executor_context, code_execution_result
|
invocation_context, code_executor_context, code_execution_result
|
||||||
)
|
)
|
||||||
yield execution_result_event
|
yield execution_result_event
|
||||||
@ -375,7 +375,7 @@ def _get_or_set_execution_id(
|
|||||||
return execution_id
|
return execution_id
|
||||||
|
|
||||||
|
|
||||||
def _post_process_code_execution_result(
|
async def _post_process_code_execution_result(
|
||||||
invocation_context: InvocationContext,
|
invocation_context: InvocationContext,
|
||||||
code_executor_context: CodeExecutorContext,
|
code_executor_context: CodeExecutorContext,
|
||||||
code_execution_result: CodeExecutionResult,
|
code_execution_result: CodeExecutionResult,
|
||||||
@ -406,7 +406,7 @@ def _post_process_code_execution_result(
|
|||||||
|
|
||||||
# Handle output files.
|
# Handle output files.
|
||||||
for output_file in code_execution_result.output_files:
|
for output_file in code_execution_result.output_files:
|
||||||
version = invocation_context.artifact_service.save_artifact(
|
version = await invocation_context.artifact_service.save_artifact(
|
||||||
app_name=invocation_context.app_name,
|
app_name=invocation_context.app_name,
|
||||||
user_id=invocation_context.user_id,
|
user_id=invocation_context.user_id,
|
||||||
session_id=invocation_context.session.id,
|
session_id=invocation_context.session.id,
|
||||||
|
@ -56,13 +56,13 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
raw_si = root_agent.canonical_global_instruction(
|
raw_si = root_agent.canonical_global_instruction(
|
||||||
ReadonlyContext(invocation_context)
|
ReadonlyContext(invocation_context)
|
||||||
)
|
)
|
||||||
si = _populate_values(raw_si, invocation_context)
|
si = await _populate_values(raw_si, invocation_context)
|
||||||
llm_request.append_instructions([si])
|
llm_request.append_instructions([si])
|
||||||
|
|
||||||
# Appends agent instructions if set.
|
# Appends agent instructions if set.
|
||||||
if agent.instruction: # not empty str
|
if agent.instruction: # not empty str
|
||||||
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
||||||
si = _populate_values(raw_si, invocation_context)
|
si = await _populate_values(raw_si, invocation_context)
|
||||||
llm_request.append_instructions([si])
|
llm_request.append_instructions([si])
|
||||||
|
|
||||||
# Maintain async generator behavior
|
# Maintain async generator behavior
|
||||||
@ -73,13 +73,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|||||||
request_processor = _InstructionsLlmRequestProcessor()
|
request_processor = _InstructionsLlmRequestProcessor()
|
||||||
|
|
||||||
|
|
||||||
def _populate_values(
|
async def _populate_values(
|
||||||
instruction_template: str,
|
instruction_template: str,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
||||||
|
|
||||||
def _replace_match(match) -> str:
|
async def _async_sub(pattern, repl_async_fn, string) -> str:
|
||||||
|
result = []
|
||||||
|
last_end = 0
|
||||||
|
for match in re.finditer(pattern, string):
|
||||||
|
result.append(string[last_end : match.start()])
|
||||||
|
replacement = await repl_async_fn(match)
|
||||||
|
result.append(replacement)
|
||||||
|
last_end = match.end()
|
||||||
|
result.append(string[last_end:])
|
||||||
|
return ''.join(result)
|
||||||
|
|
||||||
|
async def _replace_match(match) -> str:
|
||||||
var_name = match.group().lstrip('{').rstrip('}').strip()
|
var_name = match.group().lstrip('{').rstrip('}').strip()
|
||||||
optional = False
|
optional = False
|
||||||
if var_name.endswith('?'):
|
if var_name.endswith('?'):
|
||||||
@ -89,7 +100,7 @@ def _populate_values(
|
|||||||
var_name = var_name.removeprefix('artifact.')
|
var_name = var_name.removeprefix('artifact.')
|
||||||
if context.artifact_service is None:
|
if context.artifact_service is None:
|
||||||
raise ValueError('Artifact service is not initialized.')
|
raise ValueError('Artifact service is not initialized.')
|
||||||
artifact = context.artifact_service.load_artifact(
|
artifact = await context.artifact_service.load_artifact(
|
||||||
app_name=context.session.app_name,
|
app_name=context.session.app_name,
|
||||||
user_id=context.session.user_id,
|
user_id=context.session.user_id,
|
||||||
session_id=context.session.id,
|
session_id=context.session.id,
|
||||||
@ -109,7 +120,7 @@ def _populate_values(
|
|||||||
else:
|
else:
|
||||||
raise KeyError(f'Context variable not found: `{var_name}`.')
|
raise KeyError(f'Context variable not found: `{var_name}`.')
|
||||||
|
|
||||||
return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_state_name(var_name):
|
def _is_valid_state_name(var_name):
|
||||||
|
@ -186,7 +186,7 @@ class Runner:
|
|||||||
root_agent = self.agent
|
root_agent = self.agent
|
||||||
|
|
||||||
if new_message:
|
if new_message:
|
||||||
self._append_new_message_to_session(
|
await self._append_new_message_to_session(
|
||||||
session,
|
session,
|
||||||
new_message,
|
new_message,
|
||||||
invocation_context,
|
invocation_context,
|
||||||
@ -199,7 +199,7 @@ class Runner:
|
|||||||
self.session_service.append_event(session=session, event=event)
|
self.session_service.append_event(session=session, event=event)
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
def _append_new_message_to_session(
|
async def _append_new_message_to_session(
|
||||||
self,
|
self,
|
||||||
session: Session,
|
session: Session,
|
||||||
new_message: types.Content,
|
new_message: types.Content,
|
||||||
@ -225,7 +225,7 @@ class Runner:
|
|||||||
if part.inline_data is None:
|
if part.inline_data is None:
|
||||||
continue
|
continue
|
||||||
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
|
file_name = f'artifact_{invocation_context.invocation_id}_{i}'
|
||||||
self.artifact_service.save_artifact(
|
await self.artifact_service.save_artifact(
|
||||||
app_name=self.app_name,
|
app_name=self.app_name,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
session_id=session.id,
|
session_id=session.id,
|
||||||
|
@ -146,18 +146,20 @@ class AgentTool(BaseTool):
|
|||||||
|
|
||||||
if runner.artifact_service:
|
if runner.artifact_service:
|
||||||
# Forward all artifacts to parent session.
|
# Forward all artifacts to parent session.
|
||||||
for artifact_name in runner.artifact_service.list_artifact_keys(
|
async for artifact_name in runner.artifact_service.list_artifact_keys(
|
||||||
app_name=session.app_name,
|
app_name=session.app_name,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
session_id=session.id,
|
session_id=session.id,
|
||||||
):
|
):
|
||||||
if artifact := runner.artifact_service.load_artifact(
|
if artifact := await runner.artifact_service.load_artifact(
|
||||||
app_name=session.app_name,
|
app_name=session.app_name,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
session_id=session.id,
|
session_id=session.id,
|
||||||
filename=artifact_name,
|
filename=artifact_name,
|
||||||
):
|
):
|
||||||
tool_context.save_artifact(filename=artifact_name, artifact=artifact)
|
await tool_context.save_artifact(
|
||||||
|
filename=artifact_name, artifact=artifact
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not last_event
|
not last_event
|
||||||
|
@ -69,14 +69,14 @@ class LoadArtifactsTool(BaseTool):
|
|||||||
tool_context=tool_context,
|
tool_context=tool_context,
|
||||||
llm_request=llm_request,
|
llm_request=llm_request,
|
||||||
)
|
)
|
||||||
self._append_artifacts_to_llm_request(
|
await self._append_artifacts_to_llm_request(
|
||||||
tool_context=tool_context, llm_request=llm_request
|
tool_context=tool_context, llm_request=llm_request
|
||||||
)
|
)
|
||||||
|
|
||||||
def _append_artifacts_to_llm_request(
|
async def _append_artifacts_to_llm_request(
|
||||||
self, *, tool_context: ToolContext, llm_request: LlmRequest
|
self, *, tool_context: ToolContext, llm_request: LlmRequest
|
||||||
):
|
):
|
||||||
artifact_names = tool_context.list_artifacts()
|
artifact_names = await tool_context.list_artifacts()
|
||||||
if not artifact_names:
|
if not artifact_names:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ class LoadArtifactsTool(BaseTool):
|
|||||||
if function_response and function_response.name == 'load_artifacts':
|
if function_response and function_response.name == 'load_artifacts':
|
||||||
artifact_names = function_response.response['artifact_names']
|
artifact_names = function_response.response['artifact_names']
|
||||||
for artifact_name in artifact_names:
|
for artifact_name in artifact_names:
|
||||||
artifact = tool_context.load_artifact(artifact_name)
|
artifact = await tool_context.load_artifact(artifact_name)
|
||||||
llm_request.contents.append(
|
llm_request.contents.append(
|
||||||
types.Content(
|
types.Content(
|
||||||
role='user',
|
role='user',
|
||||||
|
@ -69,11 +69,11 @@ class ToolContext(CallbackContext):
|
|||||||
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
|
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
|
||||||
return AuthHandler(auth_config).get_auth_response(self.state)
|
return AuthHandler(auth_config).get_auth_response(self.state)
|
||||||
|
|
||||||
def list_artifacts(self) -> list[str]:
|
async def list_artifacts(self) -> list[str]:
|
||||||
"""Lists the filenames of the artifacts attached to the current session."""
|
"""Lists the filenames of the artifacts attached to the current session."""
|
||||||
if self._invocation_context.artifact_service is None:
|
if self._invocation_context.artifact_service is None:
|
||||||
raise ValueError('Artifact service is not initialized.')
|
raise ValueError('Artifact service is not initialized.')
|
||||||
return self._invocation_context.artifact_service.list_artifact_keys(
|
return await self._invocation_context.artifact_service.list_artifact_keys(
|
||||||
app_name=self._invocation_context.app_name,
|
app_name=self._invocation_context.app_name,
|
||||||
user_id=self._invocation_context.user_id,
|
user_id=self._invocation_context.user_id,
|
||||||
session_id=self._invocation_context.session.id,
|
session_id=self._invocation_context.session.id,
|
||||||
|
@ -152,13 +152,14 @@ def get_artifact_service(
|
|||||||
return InMemoryArtifactService()
|
return InMemoryArtifactService()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||||
)
|
)
|
||||||
def test_load_empty(service_type):
|
async def test_load_empty(service_type):
|
||||||
"""Tests loading an artifact when none exists."""
|
"""Tests loading an artifact when none exists."""
|
||||||
artifact_service = get_artifact_service(service_type)
|
artifact_service = get_artifact_service(service_type)
|
||||||
assert not artifact_service.load_artifact(
|
assert not await artifact_service.load_artifact(
|
||||||
app_name="test_app",
|
app_name="test_app",
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
session_id="session_id",
|
session_id="session_id",
|
||||||
@ -166,10 +167,11 @@ def test_load_empty(service_type):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||||
)
|
)
|
||||||
def test_save_load_delete(service_type):
|
async def test_save_load_delete(service_type):
|
||||||
"""Tests saving, loading, and deleting an artifact."""
|
"""Tests saving, loading, and deleting an artifact."""
|
||||||
artifact_service = get_artifact_service(service_type)
|
artifact_service = get_artifact_service(service_type)
|
||||||
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
|
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
|
||||||
@ -178,7 +180,7 @@ def test_save_load_delete(service_type):
|
|||||||
session_id = "123"
|
session_id = "123"
|
||||||
filename = "file456"
|
filename = "file456"
|
||||||
|
|
||||||
artifact_service.save_artifact(
|
await artifact_service.save_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -186,7 +188,7 @@ def test_save_load_delete(service_type):
|
|||||||
artifact=artifact,
|
artifact=artifact,
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
artifact_service.load_artifact(
|
await artifact_service.load_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -195,13 +197,13 @@ def test_save_load_delete(service_type):
|
|||||||
== artifact
|
== artifact
|
||||||
)
|
)
|
||||||
|
|
||||||
artifact_service.delete_artifact(
|
await artifact_service.delete_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
filename=filename,
|
filename=filename,
|
||||||
)
|
)
|
||||||
assert not artifact_service.load_artifact(
|
assert not await artifact_service.load_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -209,10 +211,11 @@ def test_save_load_delete(service_type):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||||
)
|
)
|
||||||
def test_list_keys(service_type):
|
async def test_list_keys(service_type):
|
||||||
"""Tests listing keys in the artifact service."""
|
"""Tests listing keys in the artifact service."""
|
||||||
artifact_service = get_artifact_service(service_type)
|
artifact_service = get_artifact_service(service_type)
|
||||||
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
|
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
|
||||||
@ -223,7 +226,7 @@ def test_list_keys(service_type):
|
|||||||
filenames = [filename + str(i) for i in range(5)]
|
filenames = [filename + str(i) for i in range(5)]
|
||||||
|
|
||||||
for f in filenames:
|
for f in filenames:
|
||||||
artifact_service.save_artifact(
|
await artifact_service.save_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -232,17 +235,18 @@ def test_list_keys(service_type):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
artifact_service.list_artifact_keys(
|
await artifact_service.list_artifact_keys(
|
||||||
app_name=app_name, user_id=user_id, session_id=session_id
|
app_name=app_name, user_id=user_id, session_id=session_id
|
||||||
)
|
)
|
||||||
== filenames
|
== filenames
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
|
||||||
)
|
)
|
||||||
def test_list_versions(service_type):
|
async def test_list_versions(service_type):
|
||||||
"""Tests listing versions of an artifact."""
|
"""Tests listing versions of an artifact."""
|
||||||
artifact_service = get_artifact_service(service_type)
|
artifact_service = get_artifact_service(service_type)
|
||||||
|
|
||||||
@ -258,7 +262,7 @@ def test_list_versions(service_type):
|
|||||||
]
|
]
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
artifact_service.save_artifact(
|
await artifact_service.save_artifact(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -266,7 +270,7 @@ def test_list_versions(service_type):
|
|||||||
artifact=versions[i],
|
artifact=versions[i],
|
||||||
)
|
)
|
||||||
|
|
||||||
response_versions = artifact_service.list_versions(
|
response_versions = await artifact_service.list_versions(
|
||||||
app_name=app_name,
|
app_name=app_name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user