feat: Add an option to use gcs artifact service in adk web.

Resolves https://github.com/google/adk-python/issues/309

PiperOrigin-RevId: 765772763
This commit is contained in:
Shangjie Chen 2025-06-01 00:27:39 -07:00 committed by Copybara-Service
parent 0e72efb439
commit 8d36dbda52
4 changed files with 49 additions and 5 deletions

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import shutil
@ -86,6 +87,7 @@ def to_cloud_run(
with_ui: bool,
verbosity: str,
session_db_url: str,
artifact_storage_uri: Optional[str],
adk_version: str,
):
"""Deploys an agent to Google Cloud Run.
@ -115,6 +117,7 @@ def to_cloud_run(
with_ui: Whether to deploy with UI.
verbosity: The verbosity level of the CLI.
session_db_url: The database URL to connect the session.
artifact_storage_uri: The artifact storage URI to store the artifacts.
adk_version: The ADK version to use in Cloud Run.
"""
app_name = app_name or os.path.basename(agent_folder)
@ -152,6 +155,9 @@ def to_cloud_run(
session_db_option=f'--session_db_url={session_db_url}'
if session_db_url
else '',
artifact_storage_option=f'--artifact_storage_uri={artifact_storage_uri}'
if artifact_storage_uri
else '',
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
adk_version=adk_version,
host_option=host_option,

View File

@ -430,6 +430,15 @@ def fast_api_common_options():
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
"--artifact_storage_uri",
type=str,
help=(
"Optional. The artifact storage URI to store the artifacts,"
" supported URIs: gs://<bucket name> for GCS artifact service."
),
default=None,
)
@click.option(
"--host",
type=str,
@ -490,6 +499,7 @@ def fast_api_common_options():
def cli_web(
agents_dir: str,
session_db_url: str = "",
artifact_storage_uri: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
@ -533,6 +543,7 @@ def cli_web(
app = get_fast_api_app(
agents_dir=agents_dir,
session_db_url=session_db_url,
artifact_storage_uri=artifact_storage_uri,
allow_origins=allow_origins,
web=True,
trace_to_cloud=trace_to_cloud,
@ -563,6 +574,7 @@ def cli_web(
def cli_api_server(
agents_dir: str,
session_db_url: str = "",
artifact_storage_uri: Optional[str] = None,
log_level: str = "INFO",
allow_origins: Optional[list[str]] = None,
host: str = "127.0.0.1",
@ -585,6 +597,7 @@ def cli_api_server(
get_fast_api_app(
agents_dir=agents_dir,
session_db_url=session_db_url,
artifact_storage_uri=artifact_storage_uri,
allow_origins=allow_origins,
web=False,
trace_to_cloud=trace_to_cloud,
@ -688,6 +701,15 @@ def cli_api_server(
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
),
)
@click.option(
"--artifact_storage_uri",
type=str,
help=(
"Optional. The artifact storage URI to store the artifacts, supported"
" URIs: gs://<bucket name> for GCS artifact service."
),
default=None,
)
@click.argument(
"agent",
type=click.Path(
@ -716,6 +738,7 @@ def cli_deploy_cloud_run(
with_ui: bool,
verbosity: str,
session_db_url: str,
artifact_storage_uri: Optional[str],
adk_version: str,
):
"""Deploys an agent to Cloud Run.
@ -739,6 +762,7 @@ def cli_deploy_cloud_run(
with_ui=with_ui,
verbosity=verbosity,
session_db_url=session_db_url,
artifact_storage_uri=artifact_storage_uri,
adk_version=adk_version,
)
except Exception as e:

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
@ -56,6 +55,7 @@ from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts.gcs_artifact_service import GcsArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..evaluation.eval_case import EvalCase
from ..evaluation.eval_case import SessionInput
@ -193,6 +193,7 @@ def get_fast_api_app(
*,
agents_dir: str,
session_db_url: str = "",
artifact_storage_uri: Optional[str] = None,
allow_origins: Optional[list[str]] = None,
web: bool,
trace_to_cloud: bool = False,
@ -251,13 +252,12 @@ def get_fast_api_app(
runner_dict = {}
# Build the Artifact service
artifact_service = InMemoryArtifactService()
memory_service = InMemoryMemoryService()
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
# Build the Memory service
memory_service = InMemoryMemoryService()
# Build the Session service
agent_engine_id = ""
if session_db_url:
@ -276,6 +276,18 @@ def get_fast_api_app(
else:
session_service = InMemorySessionService()
# Build the Artifact service
if artifact_storage_uri:
if artifact_storage_uri.startswith("gs://"):
gcs_bucket = artifact_storage_uri.split("://")[1]
artifact_service = GcsArtifactService(bucket_name=gcs_bucket)
else:
raise click.ClickException(
"Unsupported artifact storage URI: %s" % artifact_storage_uri
)
else:
artifact_service = InMemoryArtifactService()
# initialize Agent Loader
agent_loader = AgentLoader(agents_dir)

View File

@ -128,6 +128,7 @@ def test_to_cloud_run_happy_path(
with_ui=True,
verbosity="info",
session_db_url="sqlite://",
artifact_storage_uri="gs://bucket",
adk_version="0.0.5",
)
@ -170,6 +171,7 @@ def test_to_cloud_run_cleans_temp_dir(
with_ui=False,
verbosity="info",
session_db_url=None,
artifact_storage_uri=None,
adk_version="0.0.5",
)