structure saas with tools
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
from .gcs_artifact_service import GcsArtifactService
|
||||
from .in_memory_artifact_service import InMemoryArtifactService
|
||||
|
||||
__all__ = [
|
||||
'BaseArtifactService',
|
||||
'GcsArtifactService',
|
||||
'InMemoryArtifactService',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,128 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Abstract base class for artifact services."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
class BaseArtifactService(ABC):
|
||||
"""Abstract base class for artifact services."""
|
||||
|
||||
@abstractmethod
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
"""Saves an artifact to the artifact service storage.
|
||||
|
||||
The artifact is a file identified by the app name, user ID, session ID, and
|
||||
filename. After saving the artifact, a revision ID is returned to identify
|
||||
the artifact version.
|
||||
|
||||
Args:
|
||||
app_name: The app name.
|
||||
user_id: The user ID.
|
||||
session_id: The session ID.
|
||||
filename: The filename of the artifact.
|
||||
artifact: The artifact to save.
|
||||
|
||||
Returns:
|
||||
The revision ID. The first version of the artifact has a revision ID of 0.
|
||||
This is incremented by 1 after each successful save.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
"""Gets an artifact from the artifact service storage.
|
||||
|
||||
The artifact is a file identified by the app name, user ID, session ID, and
|
||||
filename.
|
||||
|
||||
Args:
|
||||
app_name: The app name.
|
||||
user_id: The user ID.
|
||||
session_id: The session ID.
|
||||
filename: The filename of the artifact.
|
||||
version: The version of the artifact. If None, the latest version will be
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
The artifact or None if not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
"""Lists all the artifact filenames within a session.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
|
||||
Returns:
|
||||
A list of all artifact filenames within a session.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
"""Deletes an artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
"""Lists all versions of an artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
|
||||
Returns:
|
||||
A list of all available versions of the artifact.
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,195 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An artifact service implementation using Google Cloud Storage (GCS)."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.cloud import storage
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GcsArtifactService(BaseArtifactService):
|
||||
"""An artifact service implementation using Google Cloud Storage (GCS)."""
|
||||
|
||||
def __init__(self, bucket_name: str, **kwargs):
|
||||
"""Initializes the GcsArtifactService.
|
||||
|
||||
Args:
|
||||
bucket_name: The name of the bucket to use.
|
||||
**kwargs: Keyword arguments to pass to the Google Cloud Storage client.
|
||||
"""
|
||||
self.bucket_name = bucket_name
|
||||
self.storage_client = storage.Client(**kwargs)
|
||||
self.bucket = self.storage_client.bucket(self.bucket_name)
|
||||
|
||||
def _file_has_user_namespace(self, filename: str) -> bool:
|
||||
"""Checks if the filename has a user namespace.
|
||||
|
||||
Args:
|
||||
filename: The filename to check.
|
||||
|
||||
Returns:
|
||||
True if the filename has a user namespace (starts with "user:"),
|
||||
False otherwise.
|
||||
"""
|
||||
return filename.startswith("user:")
|
||||
|
||||
def _get_blob_name(
|
||||
self,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: int,
|
||||
) -> str:
|
||||
"""Constructs the blob name in GCS.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
version: The version of the artifact.
|
||||
|
||||
Returns:
|
||||
The constructed blob name in GCS.
|
||||
"""
|
||||
if self._file_has_user_namespace(filename):
|
||||
return f"{app_name}/{user_id}/user/{filename}/{version}"
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}/{version}"
|
||||
|
||||
@override
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
version = 0 if not versions else max(versions) + 1
|
||||
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
|
||||
blob.upload_from_string(
|
||||
data=artifact.inline_data.data,
|
||||
content_type=artifact.inline_data.mime_type,
|
||||
)
|
||||
|
||||
return version
|
||||
|
||||
@override
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
if version is None:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
if not versions:
|
||||
return None
|
||||
version = max(versions)
|
||||
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
|
||||
artifact_bytes = blob.download_as_bytes()
|
||||
if not artifact_bytes:
|
||||
return None
|
||||
artifact = types.Part.from_bytes(
|
||||
data=artifact_bytes, mime_type=blob.content_type
|
||||
)
|
||||
return artifact
|
||||
|
||||
@override
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
filenames = set()
|
||||
|
||||
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
||||
session_blobs = self.storage_client.list_blobs(
|
||||
self.bucket, prefix=session_prefix
|
||||
)
|
||||
for blob in session_blobs:
|
||||
_, _, _, filename, _ = blob.name.split("/")
|
||||
filenames.add(filename)
|
||||
|
||||
user_namespace_prefix = f"{app_name}/{user_id}/user/"
|
||||
user_namespace_blobs = self.storage_client.list_blobs(
|
||||
self.bucket, prefix=user_namespace_prefix
|
||||
)
|
||||
for blob in user_namespace_blobs:
|
||||
_, _, _, filename, _ = blob.name.split("/")
|
||||
filenames.add(filename)
|
||||
|
||||
return sorted(list(filenames))
|
||||
|
||||
@override
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
versions = self.list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=filename,
|
||||
)
|
||||
for version in versions:
|
||||
blob_name = self._get_blob_name(
|
||||
app_name, user_id, session_id, filename, version
|
||||
)
|
||||
blob = self.bucket.blob(blob_name)
|
||||
blob.delete()
|
||||
return
|
||||
|
||||
@override
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
prefix = self._get_blob_name(app_name, user_id, session_id, filename, "")
|
||||
blobs = self.storage_client.list_blobs(self.bucket, prefix=prefix)
|
||||
versions = []
|
||||
for blob in blobs:
|
||||
_, _, _, _, version = blob.name.split("/")
|
||||
versions.append(int(version))
|
||||
return versions
|
||||
@@ -0,0 +1,133 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An in-memory implementation of the artifact service."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
"""An in-memory implementation of the artifact service."""
|
||||
|
||||
artifacts: dict[str, list[types.Part]] = Field(default_factory=dict)
|
||||
|
||||
def _file_has_user_namespace(self, filename: str) -> bool:
|
||||
"""Checks if the filename has a user namespace.
|
||||
|
||||
Args:
|
||||
filename: The filename to check.
|
||||
|
||||
Returns:
|
||||
True if the filename has a user namespace (starts with "user:"),
|
||||
False otherwise.
|
||||
"""
|
||||
return filename.startswith("user:")
|
||||
|
||||
def _artifact_path(
|
||||
self, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> str:
|
||||
"""Constructs the artifact path.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session.
|
||||
filename: The name of the artifact file.
|
||||
|
||||
Returns:
|
||||
The constructed artifact path.
|
||||
"""
|
||||
if self._file_has_user_namespace(filename):
|
||||
return f"{app_name}/{user_id}/user/{filename}"
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}"
|
||||
|
||||
@override
|
||||
def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
) -> int:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
if path not in self.artifacts:
|
||||
self.artifacts[path] = []
|
||||
version = len(self.artifacts[path])
|
||||
self.artifacts[path].append(artifact)
|
||||
return version
|
||||
|
||||
@override
|
||||
def load_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
filename: str,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[types.Part]:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
versions = self.artifacts.get(path)
|
||||
if not versions:
|
||||
return None
|
||||
if version is None:
|
||||
version = -1
|
||||
return versions[version]
|
||||
|
||||
@override
|
||||
def list_artifact_keys(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> list[str]:
|
||||
session_prefix = f"{app_name}/{user_id}/{session_id}/"
|
||||
usernamespace_prefix = f"{app_name}/{user_id}/user/"
|
||||
filenames = []
|
||||
for path in self.artifacts:
|
||||
if path.startswith(session_prefix):
|
||||
filename = path.removeprefix(session_prefix)
|
||||
filenames.append(filename)
|
||||
elif path.startswith(usernamespace_prefix):
|
||||
filename = path.removeprefix(usernamespace_prefix)
|
||||
filenames.append(filename)
|
||||
return sorted(filenames)
|
||||
|
||||
@override
|
||||
def delete_artifact(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> None:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
if not self.artifacts.get(path):
|
||||
return None
|
||||
self.artifacts.pop(path, None)
|
||||
|
||||
@override
|
||||
def list_versions(
|
||||
self, *, app_name: str, user_id: str, session_id: str, filename: str
|
||||
) -> list[int]:
|
||||
path = self._artifact_path(app_name, user_id, session_id, filename)
|
||||
versions = self.artifacts.get(path)
|
||||
if not versions:
|
||||
return []
|
||||
return list(range(len(versions)))
|
||||
Reference in New Issue
Block a user