ADK changes

PiperOrigin-RevId: 755201925
This commit is contained in:
Shangjie Chen
2025-05-05 21:57:51 -07:00
committed by Copybara-Service
parent 6dec235c13
commit 905c20dad6
12 changed files with 86 additions and 69 deletions

View File

@@ -152,13 +152,14 @@ def get_artifact_service(
return InMemoryArtifactService()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_load_empty(service_type):
async def test_load_empty(service_type):
"""Tests loading an artifact when none exists."""
artifact_service = get_artifact_service(service_type)
assert not artifact_service.load_artifact(
assert not await artifact_service.load_artifact(
app_name="test_app",
user_id="test_user",
session_id="session_id",
@@ -166,10 +167,11 @@ def test_load_empty(service_type):
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_save_load_delete(service_type):
async def test_save_load_delete(service_type):
"""Tests saving, loading, and deleting an artifact."""
artifact_service = get_artifact_service(service_type)
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
@@ -178,7 +180,7 @@ def test_save_load_delete(service_type):
session_id = "123"
filename = "file456"
artifact_service.save_artifact(
await artifact_service.save_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -186,7 +188,7 @@ def test_save_load_delete(service_type):
artifact=artifact,
)
assert (
artifact_service.load_artifact(
await artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -195,13 +197,13 @@ def test_save_load_delete(service_type):
== artifact
)
artifact_service.delete_artifact(
await artifact_service.delete_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
)
assert not artifact_service.load_artifact(
assert not await artifact_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -209,10 +211,11 @@ def test_save_load_delete(service_type):
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_list_keys(service_type):
async def test_list_keys(service_type):
"""Tests listing keys in the artifact service."""
artifact_service = get_artifact_service(service_type)
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
@@ -223,7 +226,7 @@ def test_list_keys(service_type):
filenames = [filename + str(i) for i in range(5)]
for f in filenames:
artifact_service.save_artifact(
await artifact_service.save_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -232,17 +235,18 @@ def test_list_keys(service_type):
)
assert (
artifact_service.list_artifact_keys(
await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id, session_id=session_id
)
== filenames
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_list_versions(service_type):
async def test_list_versions(service_type):
"""Tests listing versions of an artifact."""
artifact_service = get_artifact_service(service_type)
@@ -258,7 +262,7 @@ def test_list_versions(service_type):
]
for i in range(3):
artifact_service.save_artifact(
await artifact_service.save_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
@@ -266,7 +270,7 @@ def test_list_versions(service_type):
artifact=versions[i],
)
response_versions = artifact_service.list_versions(
response_versions = await artifact_service.list_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,