adk-python/tests/unittests/artifacts/test_artifact_service.py
Jack Sun 05142a07cc
Moves unittests to root folder and adds github action to run unit tests. (#72)
* Move unit tests to root package.

* Adds deps to "test" extra, and marks two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
2025-04-11 08:25:59 -07:00

277 lines
7.2 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the artifact service."""
import enum
from typing import Optional
from typing import Union
from google.adk.artifacts import GcsArtifactService
from google.adk.artifacts import InMemoryArtifactService
from google.genai import types
import pytest
Enum = enum.Enum
class ArtifactServiceType(enum.Enum):
    """Selects which artifact service backend a test should exercise."""

    # Each member's value is simply its own name as a string.
    IN_MEMORY = "IN_MEMORY"
    GCS = "GCS"
class MockBlob:
    """In-memory stand-in for a GCS Blob object.

    Only a few Blob methods are implemented (upload, download, delete), which
    lets code that talks to GCS be tested without connecting to a real bucket.
    """

    def __init__(self, name: str) -> None:
        """Creates a blob that holds no data yet.

        Args:
            name: The name of the blob.
        """
        self.name = name
        # Both fields stay None until data is uploaded (and again after delete).
        self.content: Optional[bytes] = None
        self.content_type: Optional[str] = None

    def upload_from_string(
        self, data: Union[str, bytes], content_type: Optional[str] = None
    ) -> None:
        """Stores the given data as the blob's content.

        Args:
            data: The payload to store; a str is encoded as UTF-8, bytes are
                stored unchanged.
            content_type: The content type to record. When omitted (or empty),
                any previously recorded content type is left untouched.

        Raises:
            TypeError: If data is neither str nor bytes.
        """
        if not isinstance(data, (str, bytes)):
            raise TypeError("data must be str or bytes")
        self.content = data.encode("utf-8") if isinstance(data, str) else data
        if content_type:
            self.content_type = content_type

    def download_as_bytes(self) -> bytes:
        """Returns the blob's stored content.

        Returns:
            The stored bytes, or b"" when nothing has been uploaded or the
            blob has been deleted.
        """
        return b"" if self.content is None else self.content

    def delete(self) -> None:
        """Clears the blob's content and content type."""
        self.content = self.content_type = None
class MockBucket:
    """In-memory stand-in for a GCS Bucket object."""

    def __init__(self, name: str) -> None:
        """Creates a bucket with no blobs.

        Args:
            name: The name of the bucket.
        """
        self.name = name
        # Every blob handed out by blob(), keyed by blob name.
        self.blobs: dict[str, MockBlob] = {}

    def blob(self, blob_name: str) -> MockBlob:
        """Returns the blob with the given name, creating it on first request.

        A newly created blob is registered in `blobs` immediately, even though
        it has no content until something is uploaded to it.

        Args:
            blob_name: The name of the blob.

        Returns:
            The MockBlob registered under blob_name.
        """
        try:
            return self.blobs[blob_name]
        except KeyError:
            created = MockBlob(blob_name)
            self.blobs[blob_name] = created
            return created
class MockClient:
    """In-memory stand-in for the GCS Client."""

    def __init__(self) -> None:
        """Creates a client that knows no buckets yet."""
        # Every bucket handed out by bucket(), keyed by bucket name.
        self.buckets: dict[str, MockBucket] = {}

    def bucket(self, bucket_name: str) -> MockBucket:
        """Returns the bucket with the given name, creating it on first request."""
        try:
            return self.buckets[bucket_name]
        except KeyError:
            created = MockBucket(bucket_name)
            self.buckets[bucket_name] = created
            return created

    def list_blobs(self, bucket: MockBucket, prefix: Optional[str] = None):
        """Lists the blobs registered in a bucket.

        Args:
            bucket: The bucket whose blobs are listed.
            prefix: When given and non-empty, only blobs whose name starts
                with it are returned.

        Returns:
            A list of the matching blobs.
        """
        if not prefix:
            return list(bucket.blobs.values())
        return [
            candidate
            for blob_name, candidate in bucket.blobs.items()
            if blob_name.startswith(prefix)
        ]
def mock_gcs_artifact_service():
    """Creates a GCS artifact service backed entirely by in-memory mocks.

    The real storage client class is patched out while the service is being
    constructed, so building this fixture does not need Google Cloud
    credentials or network access.

    Returns:
        A GcsArtifactService whose `storage_client` is a MockClient and whose
        `bucket` is that client's "test_bucket" MockBucket.
    """
    # Local import: keeps this helper self-contained without touching the
    # file-level import block.
    from unittest import mock

    client = MockClient()
    # NOTE(review): assumes GcsArtifactService builds its client through
    # google.cloud.storage.Client inside its constructor -- confirm against
    # the service implementation. If the constructor obtains its client some
    # other way, the explicit assignments below still install the mocks.
    with mock.patch("google.cloud.storage.Client", return_value=client):
        service = GcsArtifactService(bucket_name="test_bucket")
    # Assigned explicitly so the fixture does not rely on how the constructor
    # wires up its client and bucket.
    service.storage_client = client
    service.bucket = client.bucket("test_bucket")
    return service
def get_artifact_service(
    service_type: ArtifactServiceType = ArtifactServiceType.IN_MEMORY,
):
    """Builds the artifact service under test for the requested backend.

    Args:
        service_type: Which backend to build. GCS yields the mock-backed GCS
            service; any other value yields an InMemoryArtifactService.

    Returns:
        The artifact service instance.
    """
    wants_gcs = service_type == ArtifactServiceType.GCS
    return mock_gcs_artifact_service() if wants_gcs else InMemoryArtifactService()
@pytest.mark.parametrize(
    "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_load_empty(service_type):
    """Loading an artifact that was never saved yields a falsy result."""
    service = get_artifact_service(service_type)
    lookup = {
        "app_name": "test_app",
        "user_id": "test_user",
        "session_id": "session_id",
        "filename": "filename",
    }
    loaded = service.load_artifact(**lookup)
    assert not loaded
@pytest.mark.parametrize(
    "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_save_load_delete(service_type):
    """A saved artifact loads back equal, and loads as falsy once deleted."""
    service = get_artifact_service(service_type)
    payload = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
    # The same coordinates identify the artifact for every call below.
    key = {
        "app_name": "app0",
        "user_id": "user0",
        "session_id": "123",
        "filename": "file456",
    }

    service.save_artifact(**key, artifact=payload)
    round_tripped = service.load_artifact(**key)
    assert round_tripped == payload

    service.delete_artifact(**key)
    after_delete = service.load_artifact(**key)
    assert not after_delete
@pytest.mark.parametrize(
    "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_list_keys(service_type):
    """Listing keys returns every saved filename, in the order they were saved."""
    service = get_artifact_service(service_type)
    payload = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
    session = {"app_name": "app0", "user_id": "user0", "session_id": "123"}
    # filename0 .. filename4
    expected_keys = [f"filename{index}" for index in range(5)]

    for key in expected_keys:
        service.save_artifact(**session, filename=key, artifact=payload)

    listed = service.list_artifact_keys(**session)
    assert listed == expected_keys
@pytest.mark.parametrize(
    "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS]
)
def test_list_versions(service_type):
    """Saving one filename three times yields versions 0, 1 and 2."""
    service = get_artifact_service(service_type)
    key = {
        "app_name": "app0",
        "user_id": "user0",
        "session_id": "123",
        "filename": "filename",
    }
    # Three distinct payloads: the numbers 0..2 as two-byte big-endian values.
    payloads = [
        types.Part.from_bytes(
            data=number.to_bytes(2, byteorder="big"), mime_type="text/plain"
        )
        for number in range(3)
    ]

    for payload in payloads:
        service.save_artifact(**key, artifact=payload)

    listed_versions = service.list_versions(**key)
    assert listed_versions == [0, 1, 2]