From 8d36dbda520b1c0dec148e1e1d84e36ddcb9cb95 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Sun, 1 Jun 2025 00:27:39 -0700 Subject: [PATCH] feat: Add an option to use gcs artifact service in adk web. Resolves https://github.com/google/adk-python/issues/309 PiperOrigin-RevId: 765772763 --- src/google/adk/cli/cli_deploy.py | 6 +++++ src/google/adk/cli/cli_tools_click.py | 24 ++++++++++++++++++++ src/google/adk/cli/fast_api.py | 22 ++++++++++++++---- tests/unittests/cli/utils/test_cli_deploy.py | 2 ++ 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 9175d11..a478799 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os import shutil @@ -86,6 +87,7 @@ def to_cloud_run( with_ui: bool, verbosity: str, session_db_url: str, + artifact_storage_uri: Optional[str], adk_version: str, ): """Deploys an agent to Google Cloud Run. @@ -115,6 +117,7 @@ def to_cloud_run( with_ui: Whether to deploy with UI. verbosity: The verbosity level of the CLI. session_db_url: The database URL to connect the session. + artifact_storage_uri: The artifact storage URI to store the artifacts. adk_version: The ADK version to use in Cloud Run. """ app_name = app_name or os.path.basename(agent_folder) @@ -152,6 +155,9 @@ def to_cloud_run( session_db_option=f'--session_db_url={session_db_url}' if session_db_url else '', + artifact_storage_option=f'--artifact_storage_uri={artifact_storage_uri}' + if artifact_storage_uri + else '', trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '', adk_version=adk_version, host_option=host_option, diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 3fcf03a..3d2e5d3 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -430,6 +430,15 @@ def fast_api_common_options(): - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs.""" ), ) + @click.option( + "--artifact_storage_uri", + type=str, + help=( + "Optional. The artifact storage URI to store the artifacts," + " supported URIs: gs:// for GCS artifact service." + ), + default=None, + ) @click.option( "--host", type=str, @@ -490,6 +499,7 @@ def fast_api_common_options(): def cli_web( agents_dir: str, session_db_url: str = "", + artifact_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -533,6 +543,7 @@ def cli_web( app = get_fast_api_app( agents_dir=agents_dir, session_db_url=session_db_url, + artifact_storage_uri=artifact_storage_uri, allow_origins=allow_origins, web=True, trace_to_cloud=trace_to_cloud, @@ -563,6 +574,7 @@ def cli_web( def cli_api_server( agents_dir: str, session_db_url: str = "", + artifact_storage_uri: Optional[str] = None, log_level: str = "INFO", allow_origins: Optional[list[str]] = None, host: str = "127.0.0.1", @@ -585,6 +597,7 @@ def cli_api_server( get_fast_api_app( agents_dir=agents_dir, session_db_url=session_db_url, + artifact_storage_uri=artifact_storage_uri, allow_origins=allow_origins, web=False, trace_to_cloud=trace_to_cloud, @@ -688,6 +701,15 @@ def cli_api_server( - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs.""" ), ) +@click.option( + "--artifact_storage_uri", + type=str, + help=( + "Optional. The artifact storage URI to store the artifacts, supported" + " URIs: gs:// for GCS artifact service." + ), + default=None, +) @click.argument( "agent", type=click.Path( @@ -716,6 +738,7 @@ def cli_deploy_cloud_run( with_ui: bool, verbosity: str, session_db_url: str, + artifact_storage_uri: Optional[str], adk_version: str, ): """Deploys an agent to Cloud Run. @@ -739,6 +762,7 @@ def cli_deploy_cloud_run( with_ui=with_ui, verbosity=verbosity, session_db_url=session_db_url, + artifact_storage_uri=artifact_storage_uri, adk_version=adk_version, ) except Exception as e: diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index f9f89bf..5291dd0 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import asyncio @@ -56,6 +55,7 @@ from ..agents.live_request_queue import LiveRequest from ..agents.live_request_queue import LiveRequestQueue from ..agents.llm_agent import Agent from ..agents.run_config import StreamingMode +from ..artifacts.gcs_artifact_service import GcsArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..evaluation.eval_case import EvalCase from ..evaluation.eval_case import SessionInput @@ -193,6 +193,7 @@ def get_fast_api_app( *, agents_dir: str, session_db_url: str = "", + artifact_storage_uri: Optional[str] = None, allow_origins: Optional[list[str]] = None, web: bool, trace_to_cloud: bool = False, @@ -251,13 +252,12 @@ def get_fast_api_app( runner_dict = {} - # Build the Artifact service - artifact_service = InMemoryArtifactService() - memory_service = InMemoryMemoryService() - eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir) eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir) + # Build the Memory service + memory_service = InMemoryMemoryService() + # Build the Session service agent_engine_id = "" if session_db_url: @@ -276,6 +276,18 @@ def get_fast_api_app( else: session_service = InMemorySessionService() + # Build the Artifact service + if artifact_storage_uri: + if artifact_storage_uri.startswith("gs://"): + gcs_bucket = artifact_storage_uri.split("://")[1] + artifact_service = GcsArtifactService(bucket_name=gcs_bucket) + else: + raise click.ClickException( + "Unsupported artifact storage URI: %s" % artifact_storage_uri + ) + else: + artifact_service = InMemoryArtifactService() + # initialize Agent Loader agent_loader = AgentLoader(agents_dir) diff --git a/tests/unittests/cli/utils/test_cli_deploy.py b/tests/unittests/cli/utils/test_cli_deploy.py index 6f8c291..316aa04 100644 --- a/tests/unittests/cli/utils/test_cli_deploy.py +++ b/tests/unittests/cli/utils/test_cli_deploy.py @@ -128,6 +128,7 @@ def test_to_cloud_run_happy_path( with_ui=True, verbosity="info", session_db_url="sqlite://", + artifact_storage_uri="gs://bucket", adk_version="0.0.5", ) @@ -170,6 +171,7 @@ def test_to_cloud_run_cleans_temp_dir( with_ui=False, verbosity="info", session_db_url=None, + artifact_storage_uri=None, adk_version="0.0.5", )