ADK changes

PiperOrigin-RevId: 755201925
This commit is contained in:
Shangjie Chen 2025-05-05 21:57:51 -07:00 committed by Copybara-Service
parent 6dec235c13
commit 905c20dad6
12 changed files with 86 additions and 69 deletions

View File

@ -65,7 +65,7 @@ class CallbackContext(ReadonlyContext):
"""The user content that started this invocation. READONLY field.""" """The user content that started this invocation. READONLY field."""
return self._invocation_context.user_content return self._invocation_context.user_content
def load_artifact( async def load_artifact(
self, filename: str, version: Optional[int] = None self, filename: str, version: Optional[int] = None
) -> Optional[types.Part]: ) -> Optional[types.Part]:
"""Loads an artifact attached to the current session. """Loads an artifact attached to the current session.
@ -80,7 +80,7 @@ class CallbackContext(ReadonlyContext):
""" """
if self._invocation_context.artifact_service is None: if self._invocation_context.artifact_service is None:
raise ValueError("Artifact service is not initialized.") raise ValueError("Artifact service is not initialized.")
return self._invocation_context.artifact_service.load_artifact( return await self._invocation_context.artifact_service.load_artifact(
app_name=self._invocation_context.app_name, app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id, user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id, session_id=self._invocation_context.session.id,
@ -88,7 +88,7 @@ class CallbackContext(ReadonlyContext):
version=version, version=version,
) )
def save_artifact(self, filename: str, artifact: types.Part) -> int: async def save_artifact(self, filename: str, artifact: types.Part) -> int:
"""Saves an artifact and records it as delta for the current session. """Saves an artifact and records it as delta for the current session.
Args: Args:
@ -100,7 +100,7 @@ class CallbackContext(ReadonlyContext):
""" """
if self._invocation_context.artifact_service is None: if self._invocation_context.artifact_service is None:
raise ValueError("Artifact service is not initialized.") raise ValueError("Artifact service is not initialized.")
version = self._invocation_context.artifact_service.save_artifact( version = await self._invocation_context.artifact_service.save_artifact(
app_name=self._invocation_context.app_name, app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id, user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id, session_id=self._invocation_context.session.id,

View File

@ -25,7 +25,7 @@ class BaseArtifactService(ABC):
"""Abstract base class for artifact services.""" """Abstract base class for artifact services."""
@abstractmethod @abstractmethod
def save_artifact( async def save_artifact(
self, self,
*, *,
app_name: str, app_name: str,
@ -53,7 +53,7 @@ class BaseArtifactService(ABC):
""" """
@abstractmethod @abstractmethod
def load_artifact( async def load_artifact(
self, self,
*, *,
app_name: str, app_name: str,
@ -81,7 +81,7 @@ class BaseArtifactService(ABC):
pass pass
@abstractmethod @abstractmethod
def list_artifact_keys( async def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> list[str]: ) -> list[str]:
"""Lists all the artifact filenames within a session. """Lists all the artifact filenames within a session.
@ -97,7 +97,7 @@ class BaseArtifactService(ABC):
pass pass
@abstractmethod @abstractmethod
def delete_artifact( async def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None: ) -> None:
"""Deletes an artifact. """Deletes an artifact.
@ -111,7 +111,7 @@ class BaseArtifactService(ABC):
pass pass
@abstractmethod @abstractmethod
def list_versions( async def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]: ) -> list[int]:
"""Lists all versions of an artifact. """Lists all versions of an artifact.

View File

@ -77,7 +77,7 @@ class GcsArtifactService(BaseArtifactService):
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}" return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
@override @override
def save_artifact( async def save_artifact(
self, self,
*, *,
app_name: str, app_name: str,
@ -86,7 +86,7 @@ class GcsArtifactService(BaseArtifactService):
filename: str, filename: str,
artifact: types.Part, artifact: types.Part,
) -> int: ) -> int:
versions = self.list_versions( versions = await self.list_versions(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -107,7 +107,7 @@ class GcsArtifactService(BaseArtifactService):
return version return version
@override @override
def load_artifact( async def load_artifact(
self, self,
*, *,
app_name: str, app_name: str,
@ -117,7 +117,7 @@ class GcsArtifactService(BaseArtifactService):
version: Optional[int] = None, version: Optional[int] = None,
) -> Optional[types.Part]: ) -> Optional[types.Part]:
if version is None: if version is None:
versions = self.list_versions( versions = await self.list_versions(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -141,7 +141,7 @@ class GcsArtifactService(BaseArtifactService):
return artifact return artifact
@override @override
def list_artifact_keys( async def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> list[str]: ) -> list[str]:
filenames = set() filenames = set()
@ -165,10 +165,10 @@ class GcsArtifactService(BaseArtifactService):
return sorted(list(filenames)) return sorted(list(filenames))
@override @override
def delete_artifact( async def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None: ) -> None:
versions = self.list_versions( versions = await self.list_versions(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -183,7 +183,7 @@ class GcsArtifactService(BaseArtifactService):
return return
@override @override
def list_versions( async def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]: ) -> list[int]:
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "") prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")

View File

@ -63,7 +63,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return f"{app_name}/{user_id}/{session_id}/{filename}" return f"{app_name}/{user_id}/{session_id}/{filename}"
@override @override
def save_artifact( async def save_artifact(
self, self,
*, *,
app_name: str, app_name: str,
@ -80,7 +80,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return version return version
@override @override
def load_artifact( async def load_artifact(
self, self,
*, *,
app_name: str, app_name: str,
@ -98,7 +98,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return versions[version] return versions[version]
@override @override
def list_artifact_keys( async def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str self, *, app_name: str, user_id: str, session_id: str
) -> list[str]: ) -> list[str]:
session_prefix = f"{app_name}/{user_id}/{session_id}/" session_prefix = f"{app_name}/{user_id}/{session_id}/"
@ -114,7 +114,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return sorted(filenames) return sorted(filenames)
@override @override
def delete_artifact( async def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None: ) -> None:
path = self._artifact_path(app_name, user_id, session_id, filename) path = self._artifact_path(app_name, user_id, session_id, filename)
@ -123,7 +123,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
self.artifacts.pop(path, None) self.artifacts.pop(path, None)
@override @override
def list_versions( async def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]: ) -> list[int]:
path = self._artifact_path(app_name, user_id, session_id, filename) path = self._artifact_path(app_name, user_id, session_id, filename)

View File

@ -503,7 +503,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def load_artifact( async def load_artifact(
app_name: str, app_name: str,
user_id: str, user_id: str,
session_id: str, session_id: str,
@ -511,7 +511,7 @@ def get_fast_api_app(
version: Optional[int] = Query(None), version: Optional[int] = Query(None),
) -> Optional[types.Part]: ) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
artifact = artifact_service.load_artifact( artifact = await artifact_service.load_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -526,7 +526,7 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def load_artifact_version( async def load_artifact_version(
app_name: str, app_name: str,
user_id: str, user_id: str,
session_id: str, session_id: str,
@ -534,7 +534,7 @@ def get_fast_api_app(
version_id: int, version_id: int,
) -> Optional[types.Part]: ) -> Optional[types.Part]:
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
artifact = artifact_service.load_artifact( artifact = await artifact_service.load_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -549,11 +549,11 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def list_artifact_names( async def list_artifact_names(
app_name: str, user_id: str, session_id: str app_name: str, user_id: str, session_id: str
) -> list[str]: ) -> list[str]:
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
return artifact_service.list_artifact_keys( return await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
@ -561,11 +561,11 @@ def get_fast_api_app(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
response_model_exclude_none=True, response_model_exclude_none=True,
) )
def list_artifact_versions( async def list_artifact_versions(
app_name: str, user_id: str, session_id: str, artifact_name: str app_name: str, user_id: str, session_id: str, artifact_name: str
) -> list[int]: ) -> list[int]:
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
return artifact_service.list_versions( return await artifact_service.list_versions(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -575,11 +575,11 @@ def get_fast_api_app(
@app.delete( @app.delete(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
) )
def delete_artifact( async def delete_artifact(
app_name: str, user_id: str, session_id: str, artifact_name: str app_name: str, user_id: str, session_id: str, artifact_name: str
): ):
app_name = agent_engine_id if agent_engine_id else app_name app_name = agent_engine_id if agent_engine_id else app_name
artifact_service.delete_artifact( await artifact_service.delete_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,

View File

@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
if not invocation_context.agent.code_executor: if not invocation_context.agent.code_executor:
return return
for event in _run_pre_processor(invocation_context, llm_request): async for event in _run_pre_processor(invocation_context, llm_request):
yield event yield event
# Convert the code execution parts to text parts. # Convert the code execution parts to text parts.
@ -159,10 +159,10 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
response_processor = _CodeExecutionResponseProcessor() response_processor = _CodeExecutionResponseProcessor()
def _run_pre_processor( async def _run_pre_processor(
invocation_context: InvocationContext, invocation_context: InvocationContext,
llm_request: LlmRequest, llm_request: LlmRequest,
) -> Generator[Event, None, None]: ) -> AsyncGenerator[Event, None]:
"""Pre-process the user message by adding the user message to the Colab notebook.""" """Pre-process the user message by adding the user message to the Colab notebook."""
from ...agents.llm_agent import LlmAgent from ...agents.llm_agent import LlmAgent
@ -242,7 +242,7 @@ def _run_pre_processor(
code_executor_context.add_processed_file_names([file.name]) code_executor_context.add_processed_file_names([file.name])
# Emit the execution result, and add it to the LLM request. # Emit the execution result, and add it to the LLM request.
execution_result_event = _post_process_code_execution_result( execution_result_event = await _post_process_code_execution_result(
invocation_context, code_executor_context, code_execution_result invocation_context, code_executor_context, code_execution_result
) )
yield execution_result_event yield execution_result_event
@ -375,7 +375,7 @@ def _get_or_set_execution_id(
return execution_id return execution_id
def _post_process_code_execution_result( async def _post_process_code_execution_result(
invocation_context: InvocationContext, invocation_context: InvocationContext,
code_executor_context: CodeExecutorContext, code_executor_context: CodeExecutorContext,
code_execution_result: CodeExecutionResult, code_execution_result: CodeExecutionResult,
@ -406,7 +406,7 @@ def _post_process_code_execution_result(
# Handle output files. # Handle output files.
for output_file in code_execution_result.output_files: for output_file in code_execution_result.output_files:
version = invocation_context.artifact_service.save_artifact( version = await invocation_context.artifact_service.save_artifact(
app_name=invocation_context.app_name, app_name=invocation_context.app_name,
user_id=invocation_context.user_id, user_id=invocation_context.user_id,
session_id=invocation_context.session.id, session_id=invocation_context.session.id,

View File

@ -56,13 +56,13 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
raw_si = root_agent.canonical_global_instruction( raw_si = root_agent.canonical_global_instruction(
ReadonlyContext(invocation_context) ReadonlyContext(invocation_context)
) )
si = _populate_values(raw_si, invocation_context) si = await _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si]) llm_request.append_instructions([si])
# Appends agent instructions if set. # Appends agent instructions if set.
if agent.instruction: # not empty str if agent.instruction: # not empty str
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context)) raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
si = _populate_values(raw_si, invocation_context) si = await _populate_values(raw_si, invocation_context)
llm_request.append_instructions([si]) llm_request.append_instructions([si])
# Maintain async generator behavior # Maintain async generator behavior
@ -73,13 +73,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
request_processor = _InstructionsLlmRequestProcessor() request_processor = _InstructionsLlmRequestProcessor()
def _populate_values( async def _populate_values(
instruction_template: str, instruction_template: str,
context: InvocationContext, context: InvocationContext,
) -> str: ) -> str:
"""Populates values in the instruction template, e.g. state, artifact, etc.""" """Populates values in the instruction template, e.g. state, artifact, etc."""
def _replace_match(match) -> str: async def _async_sub(pattern, repl_async_fn, string) -> str:
result = []
last_end = 0
for match in re.finditer(pattern, string):
result.append(string[last_end : match.start()])
replacement = await repl_async_fn(match)
result.append(replacement)
last_end = match.end()
result.append(string[last_end:])
return ''.join(result)
async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip() var_name = match.group().lstrip('{').rstrip('}').strip()
optional = False optional = False
if var_name.endswith('?'): if var_name.endswith('?'):
@ -89,7 +100,7 @@ def _populate_values(
var_name = var_name.removeprefix('artifact.') var_name = var_name.removeprefix('artifact.')
if context.artifact_service is None: if context.artifact_service is None:
raise ValueError('Artifact service is not initialized.') raise ValueError('Artifact service is not initialized.')
artifact = context.artifact_service.load_artifact( artifact = await context.artifact_service.load_artifact(
app_name=context.session.app_name, app_name=context.session.app_name,
user_id=context.session.user_id, user_id=context.session.user_id,
session_id=context.session.id, session_id=context.session.id,
@ -109,7 +120,7 @@ def _populate_values(
else: else:
raise KeyError(f'Context variable not found: `{var_name}`.') raise KeyError(f'Context variable not found: `{var_name}`.')
return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template) return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
def _is_valid_state_name(var_name): def _is_valid_state_name(var_name):

View File

@ -186,7 +186,7 @@ class Runner:
root_agent = self.agent root_agent = self.agent
if new_message: if new_message:
self._append_new_message_to_session( await self._append_new_message_to_session(
session, session,
new_message, new_message,
invocation_context, invocation_context,
@ -199,7 +199,7 @@ class Runner:
self.session_service.append_event(session=session, event=event) self.session_service.append_event(session=session, event=event)
yield event yield event
def _append_new_message_to_session( async def _append_new_message_to_session(
self, self,
session: Session, session: Session,
new_message: types.Content, new_message: types.Content,
@ -225,7 +225,7 @@ class Runner:
if part.inline_data is None: if part.inline_data is None:
continue continue
file_name = f'artifact_{invocation_context.invocation_id}_{i}' file_name = f'artifact_{invocation_context.invocation_id}_{i}'
self.artifact_service.save_artifact( await self.artifact_service.save_artifact(
app_name=self.app_name, app_name=self.app_name,
user_id=session.user_id, user_id=session.user_id,
session_id=session.id, session_id=session.id,

View File

@ -146,18 +146,20 @@ class AgentTool(BaseTool):
if runner.artifact_service: if runner.artifact_service:
# Forward all artifacts to parent session. # Forward all artifacts to parent session.
for artifact_name in runner.artifact_service.list_artifact_keys( async for artifact_name in runner.artifact_service.list_artifact_keys(
app_name=session.app_name, app_name=session.app_name,
user_id=session.user_id, user_id=session.user_id,
session_id=session.id, session_id=session.id,
): ):
if artifact := runner.artifact_service.load_artifact( if artifact := await runner.artifact_service.load_artifact(
app_name=session.app_name, app_name=session.app_name,
user_id=session.user_id, user_id=session.user_id,
session_id=session.id, session_id=session.id,
filename=artifact_name, filename=artifact_name,
): ):
tool_context.save_artifact(filename=artifact_name, artifact=artifact) await tool_context.save_artifact(
filename=artifact_name, artifact=artifact
)
if ( if (
not last_event not last_event

View File

@ -69,14 +69,14 @@ class LoadArtifactsTool(BaseTool):
tool_context=tool_context, tool_context=tool_context,
llm_request=llm_request, llm_request=llm_request,
) )
self._append_artifacts_to_llm_request( await self._append_artifacts_to_llm_request(
tool_context=tool_context, llm_request=llm_request tool_context=tool_context, llm_request=llm_request
) )
def _append_artifacts_to_llm_request( async def _append_artifacts_to_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest self, *, tool_context: ToolContext, llm_request: LlmRequest
): ):
artifact_names = tool_context.list_artifacts() artifact_names = await tool_context.list_artifacts()
if not artifact_names: if not artifact_names:
return return
@ -96,7 +96,7 @@ class LoadArtifactsTool(BaseTool):
if function_response and function_response.name == 'load_artifacts': if function_response and function_response.name == 'load_artifacts':
artifact_names = function_response.response['artifact_names'] artifact_names = function_response.response['artifact_names']
for artifact_name in artifact_names: for artifact_name in artifact_names:
artifact = tool_context.load_artifact(artifact_name) artifact = await tool_context.load_artifact(artifact_name)
llm_request.contents.append( llm_request.contents.append(
types.Content( types.Content(
role='user', role='user',

View File

@ -69,11 +69,11 @@ class ToolContext(CallbackContext):
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential: def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
return AuthHandler(auth_config).get_auth_response(self.state) return AuthHandler(auth_config).get_auth_response(self.state)
def list_artifacts(self) -> list[str]: async def list_artifacts(self) -> list[str]:
"""Lists the filenames of the artifacts attached to the current session.""" """Lists the filenames of the artifacts attached to the current session."""
if self._invocation_context.artifact_service is None: if self._invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.') raise ValueError('Artifact service is not initialized.')
return self._invocation_context.artifact_service.list_artifact_keys( return await self._invocation_context.artifact_service.list_artifact_keys(
app_name=self._invocation_context.app_name, app_name=self._invocation_context.app_name,
user_id=self._invocation_context.user_id, user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id, session_id=self._invocation_context.session.id,

View File

@ -152,13 +152,14 @@ def get_artifact_service(
return InMemoryArtifactService() return InMemoryArtifactService()
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
) )
def test_load_empty(service_type): async def test_load_empty(service_type):
"""Tests loading an artifact when none exists.""" """Tests loading an artifact when none exists."""
artifact_service = get_artifact_service(service_type) artifact_service = get_artifact_service(service_type)
assert not artifact_service.load_artifact( assert not await artifact_service.load_artifact(
app_name="test_app", app_name="test_app",
user_id="test_user", user_id="test_user",
session_id="session_id", session_id="session_id",
@ -166,10 +167,11 @@ def test_load_empty(service_type):
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
) )
def test_save_load_delete(service_type): async def test_save_load_delete(service_type):
"""Tests saving, loading, and deleting an artifact.""" """Tests saving, loading, and deleting an artifact."""
artifact_service = get_artifact_service(service_type) artifact_service = get_artifact_service(service_type)
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
@ -178,7 +180,7 @@ def test_save_load_delete(service_type):
session_id = "123" session_id = "123"
filename = "file456" filename = "file456"
artifact_service.save_artifact( await artifact_service.save_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -186,7 +188,7 @@ def test_save_load_delete(service_type):
artifact=artifact, artifact=artifact,
) )
assert ( assert (
artifact_service.load_artifact( await artifact_service.load_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -195,13 +197,13 @@ def test_save_load_delete(service_type):
== artifact == artifact
) )
artifact_service.delete_artifact( await artifact_service.delete_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
filename=filename, filename=filename,
) )
assert not artifact_service.load_artifact( assert not await artifact_service.load_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -209,10 +211,11 @@ def test_save_load_delete(service_type):
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
) )
def test_list_keys(service_type): async def test_list_keys(service_type):
"""Tests listing keys in the artifact service.""" """Tests listing keys in the artifact service."""
artifact_service = get_artifact_service(service_type) artifact_service = get_artifact_service(service_type)
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
@ -223,7 +226,7 @@ def test_list_keys(service_type):
filenames = [filename + str(i) for i in range(5)] filenames = [filename + str(i) for i in range(5)]
for f in filenames: for f in filenames:
artifact_service.save_artifact( await artifact_service.save_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -232,17 +235,18 @@ def test_list_keys(service_type):
) )
assert ( assert (
artifact_service.list_artifact_keys( await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id, session_id=session_id app_name=app_name, user_id=user_id, session_id=session_id
) )
== filenames == filenames
) )
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
) )
def test_list_versions(service_type): async def test_list_versions(service_type):
"""Tests listing versions of an artifact.""" """Tests listing versions of an artifact."""
artifact_service = get_artifact_service(service_type) artifact_service = get_artifact_service(service_type)
@ -258,7 +262,7 @@ def test_list_versions(service_type):
] ]
for i in range(3): for i in range(3):
artifact_service.save_artifact( await artifact_service.save_artifact(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
@ -266,7 +270,7 @@ def test_list_versions(service_type):
artifact=versions[i], artifact=versions[i],
) )
response_versions = artifact_service.list_versions( response_versions = await artifact_service.list_versions(
app_name=app_name, app_name=app_name,
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,