ADK changes

PiperOrigin-RevId: 755201925
This commit is contained in:
Shangjie Chen
2025-05-05 21:57:51 -07:00
committed by Copybara-Service
parent 6dec235c13
commit 905c20dad6
12 changed files with 86 additions and 69 deletions
@@ -25,7 +25,7 @@ class BaseArtifactService(ABC):
"""Abstract base class for artifact services."""
@abstractmethod
def save_artifact(
async def save_artifact(
self,
*,
app_name: str,
@@ -53,7 +53,7 @@ class BaseArtifactService(ABC):
"""
@abstractmethod
def load_artifact(
async def load_artifact(
self,
*,
app_name: str,
@@ -81,7 +81,7 @@ class BaseArtifactService(ABC):
pass
@abstractmethod
def list_artifact_keys(
async def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str
) -> list[str]:
"""Lists all the artifact filenames within a session.
@@ -97,7 +97,7 @@ class BaseArtifactService(ABC):
pass
@abstractmethod
def delete_artifact(
async def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None:
"""Deletes an artifact.
@@ -111,7 +111,7 @@ class BaseArtifactService(ABC):
pass
@abstractmethod
def list_versions(
async def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]:
"""Lists all versions of an artifact.
@@ -77,7 +77,7 @@ class GcsArtifactService(BaseArtifactService):
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
@override
def save_artifact(
async def save_artifact(
self,
*,
app_name: str,
@@ -86,7 +86,7 @@ class GcsArtifactService(BaseArtifactService):
filename: str,
artifact: types.Part,
) -> int:
versions = self.list_versions(
versions = await self.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -107,7 +107,7 @@ class GcsArtifactService(BaseArtifactService):
return version
@override
def load_artifact(
async def load_artifact(
self,
*,
app_name: str,
@@ -117,7 +117,7 @@ class GcsArtifactService(BaseArtifactService):
version: Optional[int] = None,
) -> Optional[types.Part]:
if version is None:
versions = self.list_versions(
versions = await self.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -141,7 +141,7 @@ class GcsArtifactService(BaseArtifactService):
return artifact
@override
def list_artifact_keys(
async def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str
) -> list[str]:
filenames = set()
@@ -165,10 +165,10 @@ class GcsArtifactService(BaseArtifactService):
return sorted(list(filenames))
@override
def delete_artifact(
async def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None:
versions = self.list_versions(
versions = await self.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -183,7 +183,7 @@ class GcsArtifactService(BaseArtifactService):
return
@override
def list_versions(
async def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]:
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
@@ -63,7 +63,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return f"{app_name}/{user_id}/{session_id}/{filename}"
@override
def save_artifact(
async def save_artifact(
self,
*,
app_name: str,
@@ -80,7 +80,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return version
@override
def load_artifact(
async def load_artifact(
self,
*,
app_name: str,
@@ -98,7 +98,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return versions[version]
@override
def list_artifact_keys(
async def list_artifact_keys(
self, *, app_name: str, user_id: str, session_id: str
) -> list[str]:
session_prefix = f"{app_name}/{user_id}/{session_id}/"
@@ -114,7 +114,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return sorted(filenames)
@override
def delete_artifact(
async def delete_artifact(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> None:
path = self._artifact_path(app_name, user_id, session_id, filename)
@@ -123,7 +123,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
self.artifacts.pop(path, None)
@override
def list_versions(
async def list_versions(
self, *, app_name: str, user_id: str, session_id: str, filename: str
) -> list[int]:
path = self._artifact_path(app_name, user_id, session_id, filename)