Agent Development Kit(ADK)

An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
hangfei
2025-04-08 17:22:09 +00:00
parent f92478bd5c
commit 9827820143
299 changed files with 44398 additions and 2 deletions

View File

@@ -0,0 +1,41 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from .base_session_service import BaseSessionService
from .in_memory_session_service import InMemorySessionService
from .session import Session
from .state import State
from .vertex_ai_session_service import VertexAiSessionService
logger = logging.getLogger(__name__)
__all__ = [
'BaseSessionService',
'InMemorySessionService',
'Session',
'State',
'VertexAiSessionService',
]
try:
from .database_session_service import DatabaseSessionService
__all__.append('DatabaseSessionService')
except ImportError:
logger.debug(
'DatabaseSessionService require sqlalchemy>=2.0, please ensure it is'
' installed correctly.'
)

View File

@@ -0,0 +1,133 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any
from typing import Optional
from pydantic import BaseModel
from pydantic import Field
from ..events.event import Event
from .session import Session
from .state import State
class GetSessionConfig(BaseModel):
"""The configuration of getting a session."""
num_recent_events: Optional[int] = None
after_timestamp: Optional[float] = None
class ListSessionsResponse(BaseModel):
"""The response of listing sessions.
The events and states are not set within each Session object.
"""
sessions: list[Session] = Field(default_factory=list)
class ListEventsResponse(BaseModel):
"""The response of listing events in a session."""
events: list[Event] = Field(default_factory=list)
next_page_token: Optional[str] = None
class BaseSessionService(abc.ABC):
"""Base class for session services.
The service provides a set of methods for managing sessions and events.
"""
@abc.abstractmethod
def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
"""Creates a new session.
Args:
app_name: the name of the app.
user_id: the id of the user.
state: the initial state of the session.
session_id: the client-provided id of the session. If not provided, a
generated ID will be used.
Returns:
session: The newly created session instance.
"""
pass
@abc.abstractmethod
def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
"""Gets a session."""
pass
@abc.abstractmethod
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
"""Lists all the sessions."""
pass
@abc.abstractmethod
def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
"""Deletes a session."""
pass
@abc.abstractmethod
def list_events(
self,
*,
app_name: str,
user_id: str,
session_id: str,
) -> ListEventsResponse:
"""Lists events in a session."""
pass
def close_session(self, *, session: Session):
"""Closes a session."""
# TODO: determine whether we want to finalize the session here.
pass
def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
return event
self.__update_session_state(session, event)
session.events.append(event)
return event
def __update_session_state(self, session: Session, event: Event):
"""Updates the session state based on the event."""
if not event.actions or not event.actions.state_delta:
return
for key, value in event.actions.state_delta.items():
if key.startswith(State.TEMP_PREFIX):
continue
session.state.update({key: value})

View File

@@ -0,0 +1,522 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from datetime import datetime
import json
import logging
from typing import Any
from typing import Optional
import uuid
from sqlalchemy import delete
from sqlalchemy import Dialect
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import Text
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session as DatabaseSessionFactory
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import MetaData
from sqlalchemy.types import DateTime
from sqlalchemy.types import PickleType
from sqlalchemy.types import String
from sqlalchemy.types import TypeDecorator
from typing_extensions import override
from tzlocal import get_localzone
from ..events.event import Event
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
logger = logging.getLogger(__name__)
class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
serialization for other databases.
"""
impl = Text # Default implementation is TEXT
def load_dialect_impl(self, dialect: Dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(postgresql.JSONB)
else:
return dialect.type_descriptor(Text) # Default to Text for other dialects
def process_bind_param(self, value, dialect: Dialect):
if value is not None:
if dialect.name == "postgresql":
return value # JSONB handles dict directly
else:
return json.dumps(value) # Serialize to JSON string for TEXT
return value
def process_result_value(self, value, dialect: Dialect):
if value is not None:
if dialect.name == "postgresql":
return value # JSONB returns dict directly
else:
return json.loads(value) # Deserialize from JSON string for TEXT
return value
class Base(DeclarativeBase):
"""Base class for database tables."""
pass
class StorageSession(Base):
"""Represents a session stored in the database."""
__tablename__ = "sessions"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
state: Mapped[dict] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
storage_events: Mapped[list["StorageEvent"]] = relationship(
"StorageEvent",
back_populates="storage_session",
)
def __repr__(self):
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
class StorageEvent(Base):
"""Represents an event stored in the database."""
__tablename__ = "events"
id: Mapped[str] = mapped_column(String, primary_key=True)
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
session_id: Mapped[str] = mapped_column(String, primary_key=True)
invocation_id: Mapped[str] = mapped_column(String)
author: Mapped[str] = mapped_column(String)
branch: Mapped[str] = mapped_column(String, nullable=True)
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
content: Mapped[dict] = mapped_column(DynamicJSON)
actions: Mapped[dict] = mapped_column(PickleType)
storage_session: Mapped[StorageSession] = relationship(
"StorageSession",
back_populates="storage_events",
)
__table_args__ = (
ForeignKeyConstraint(
["app_name", "user_id", "session_id"],
["sessions.app_name", "sessions.user_id", "sessions.id"],
ondelete="CASCADE",
),
)
class StorageAppState(Base):
"""Represents an app state stored in the database."""
__tablename__ = "app_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
class StorageUserState(Base):
"""Represents a user state stored in the database."""
__tablename__ = "user_states"
app_name: Mapped[str] = mapped_column(String, primary_key=True)
user_id: Mapped[str] = mapped_column(String, primary_key=True)
state: Mapped[dict] = mapped_column(
MutableDict.as_mutable(DynamicJSON), default={}
)
update_time: Mapped[DateTime] = mapped_column(
DateTime(), default=func.now(), onupdate=func.now()
)
class DatabaseSessionService(BaseSessionService):
"""A session service that uses a database for storage."""
def __init__(self, db_url: str):
"""
Args:
db_url: The database URL to connect to.
"""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properies
supported_dialects = ["postgresql", "mysql", "sqlite"]
dialect = db_url.split("://")[0]
if dialect in supported_dialects:
db_engine = create_engine(db_url)
else:
raise ValueError(f"Unsupported database URL: {db_url}")
# Get the local timezone
local_timezone = get_localzone()
logger.info(f"Local timezone: {local_timezone}")
self.db_engine: Engine = db_engine
self.metadata: MetaData = MetaData()
self.inspector = inspect(self.db_engine)
# DB session factory method
self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
sessionmaker(bind=self.db_engine)
)
# Uncomment to recreate DB every time
# Base.metadata.drop_all(self.db_engine)
Base.metadata.create_all(self.db_engine)
@override
def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
# 1. Populate states.
# 2. Build storage session object
# 3. Add the object to the table
# 4. Build the session object with generated id
# 5. Return the session
with self.DatabaseSessionFactory() as sessionFactory:
# Fetch app and user states from storage
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
storage_user_state = sessionFactory.get(
StorageUserState, (app_name, user_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
# Create state tables if not exist
if not storage_app_state:
storage_app_state = StorageAppState(app_name=app_name, state={})
sessionFactory.add(storage_app_state)
if not storage_user_state:
storage_user_state = StorageUserState(
app_name=app_name, user_id=user_id, state={}
)
sessionFactory.add(storage_user_state)
# Extract state deltas
app_state_delta, user_state_delta, session_state = _extract_state_delta(
state
)
# Apply state delta
app_state.update(app_state_delta)
user_state.update(user_state_delta)
# Store app and user state
if app_state_delta:
storage_app_state.state = app_state
if user_state_delta:
storage_user_state.state = user_state
# Store the session
storage_session = StorageSession(
app_name=app_name,
user_id=user_id,
id=session_id,
state=session_state,
)
sessionFactory.add(storage_session)
sessionFactory.commit()
sessionFactory.refresh(storage_session)
# Merge states for response
merged_state = _merge_state(app_state, user_state, session_state)
session = Session(
app_name=str(storage_session.app_name),
user_id=str(storage_session.user_id),
id=str(storage_session.id),
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
)
return session
return None
@override
def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
session: Session = None
with self.DatabaseSessionFactory() as sessionFactory:
storage_session = sessionFactory.get(
StorageSession, (app_name, user_id, session_id)
)
if storage_session is None:
return None
storage_events = (
sessionFactory.query(StorageEvent)
.filter(StorageEvent.session_id == storage_session.id)
.filter(
StorageEvent.timestamp < config.after_timestamp
if config
else True
)
.limit(config.num_recent_events if config else None)
.all()
)
# Fetch states from storage
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
storage_user_state = sessionFactory.get(
StorageUserState, (app_name, user_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
session_state = storage_session.state
# Merge states
merged_state = _merge_state(app_state, user_state, session_state)
# Convert storage session to session
session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged_state,
last_update_time=storage_session.update_time.timestamp(),
)
session.events = [
Event(
id=e.id,
author=e.author,
branch=e.branch,
invocation_id=e.invocation_id,
content=e.content,
actions=e.actions,
timestamp=e.timestamp.timestamp(),
)
for e in storage_events
]
return session
@override
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
with self.DatabaseSessionFactory() as sessionFactory:
results = (
sessionFactory.query(StorageSession)
.filter(StorageSession.app_name == app_name)
.filter(StorageSession.user_id == user_id)
.all()
)
sessions = []
for storage_session in results:
session = Session(
app_name=app_name,
user_id=user_id,
id=storage_session.id,
state={},
last_update_time=storage_session.update_time.timestamp(),
)
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
raise ValueError("Failed to retrieve sessions.")
@override
def delete_session(
self, app_name: str, user_id: str, session_id: str
) -> None:
with self.DatabaseSessionFactory() as sessionFactory:
stmt = delete(StorageSession).where(
StorageSession.app_name == app_name,
StorageSession.user_id == user_id,
StorageSession.id == session_id,
)
sessionFactory.execute(stmt)
sessionFactory.commit()
@override
def append_event(self, session: Session, event: Event) -> Event:
logger.info(f"Append event: {event} to session {session.id}")
if event.partial and not event.content:
return event
# 1. Check if timestamp is stale
# 2. Update session attributes based on event config
# 3. Store event to table
with self.DatabaseSessionFactory() as sessionFactory:
storage_session = sessionFactory.get(
StorageSession, (session.app_name, session.user_id, session.id)
)
if storage_session.update_time.timestamp() > session.last_update_time:
raise ValueError(
f"Session last_update_time {session.last_update_time} is later than"
f" the upate_time in storage {storage_session.update_time}"
)
# Fetch states from storage
storage_app_state = sessionFactory.get(
StorageAppState, (session.app_name)
)
storage_user_state = sessionFactory.get(
StorageUserState, (session.app_name, session.user_id)
)
app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}
session_state = storage_session.state
# Extract state delta
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if event.actions:
if event.actions.state_delta:
app_state_delta, user_state_delta, session_state_delta = (
_extract_state_delta(event.actions.state_delta)
)
# Merge state
app_state.update(app_state_delta)
user_state.update(user_state_delta)
session_state.update(session_state_delta)
# Update storage
storage_app_state.state = app_state
storage_user_state.state = user_state
storage_session.state = session_state
encoded_content = event.content.model_dump(exclude_none=True)
storage_event = StorageEvent(
id=event.id,
invocation_id=event.invocation_id,
author=event.author,
branch=event.branch,
content=encoded_content,
actions=event.actions,
session_id=session.id,
app_name=session.app_name,
user_id=session.user_id,
timestamp=datetime.fromtimestamp(event.timestamp),
)
sessionFactory.add(storage_event)
sessionFactory.commit()
sessionFactory.refresh(storage_session)
# Update timestamp with commit time
session.last_update_time = storage_session.update_time.timestamp()
# Also update the in-memory session
super().append_event(session=session, event=event)
return event
@override
def list_events(
self,
*,
app_name: str,
user_id: str,
session_id: str,
) -> ListEventsResponse:
pass
def convert_event(event: StorageEvent) -> Event:
"""Converts a storage event to an event."""
return Event(
id=event.id,
author=event.author,
branch=event.branch,
invocation_id=event.invocation_id,
content=event.content,
actions=event.actions,
timestamp=event.timestamp.timestamp(),
)
def _extract_state_delta(state: dict):
app_state_delta = {}
user_state_delta = {}
session_state_delta = {}
if state:
for key in state.keys():
if key.startswith(State.APP_PREFIX):
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
elif key.startswith(State.USER_PREFIX):
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
elif not key.startswith(State.TEMP_PREFIX):
session_state_delta[key] = state[key]
return app_state_delta, user_state_delta, session_state_delta
def _merge_state(app_state, user_state, session_state):
# Merge states for response
merged_state = copy.deepcopy(session_state)
for key in app_state.keys():
merged_state[State.APP_PREFIX + key] = app_state[key]
for key in user_state.keys():
merged_state[State.USER_PREFIX + key] = user_state[key]
return merged_state

View File

@@ -0,0 +1,206 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import time
from typing import Any
from typing import Optional
import uuid
from typing_extensions import override
from ..events.event import Event
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session
from .state import State
class InMemorySessionService(BaseSessionService):
"""An in-memory implementation of the session service."""
def __init__(self):
# A map from app name to a map from user ID to a map from session ID to session.
self.sessions: dict[str, dict[str, dict[str, Session]]] = {}
# A map from app name to a map from user ID to a map from key to the value.
self.user_state: dict[str, dict[str, dict[str, Any]]] = {}
# A map from app name to a map from key to the value.
self.app_state: dict[str, dict[str, Any]] = {}
@override
def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
session_id = (
session_id.strip()
if session_id and session_id.strip()
else str(uuid.uuid4())
)
session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=state or {},
last_update_time=time.time(),
)
if app_name not in self.sessions:
self.sessions[app_name] = {}
if user_id not in self.sessions[app_name]:
self.sessions[app_name][user_id] = {}
self.sessions[app_name][user_id][session_id] = session
copied_session = copy.deepcopy(session)
return self._merge_state(app_name, user_id, copied_session)
@override
def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Session:
if app_name not in self.sessions:
return None
if user_id not in self.sessions[app_name]:
return None
if session_id not in self.sessions[app_name][user_id]:
return None
session = self.sessions[app_name][user_id].get(session_id)
copied_session = copy.deepcopy(session)
if config:
if config.num_recent_events:
copied_session.events = copied_session.events[
-config.num_recent_events :
]
elif config.after_timestamp:
i = len(session.events) - 1
while i >= 0:
if copied_session.events[i].timestamp < config.after_timestamp:
break
i -= 1
if i >= 0:
copied_session.events = copied_session.events[i:]
return self._merge_state(app_name, user_id, copied_session)
def _merge_state(self, app_name: str, user_id: str, copied_session: Session):
# Merge app state
if app_name in self.app_state:
for key in self.app_state[app_name].keys():
copied_session.state[State.APP_PREFIX + key] = self.app_state[app_name][
key
]
if (
app_name not in self.user_state
or user_id not in self.user_state[app_name]
):
return copied_session
# Merge session state with user state.
for key in self.user_state[app_name][user_id].keys():
copied_session.state[State.USER_PREFIX + key] = self.user_state[app_name][
user_id
][key]
return copied_session
@override
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
empty_response = ListSessionsResponse()
if app_name not in self.sessions:
return empty_response
if user_id not in self.sessions[app_name]:
return empty_response
sessions_without_events = []
for session in self.sessions[app_name][user_id].values():
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session.state = {}
sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)
@override
def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
if (
self.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
is None
):
return None
self.sessions[app_name][user_id].pop(session_id)
@override
def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Update the storage session
app_name = session.app_name
user_id = session.user_id
session_id = session.id
if app_name not in self.sessions:
return event
if user_id not in self.sessions[app_name]:
return event
if session_id not in self.sessions[app_name][user_id]:
return event
if event.actions and event.actions.state_delta:
for key in event.actions.state_delta:
if key.startswith(State.APP_PREFIX):
self.app_state.setdefault(app_name, {})[
key.removeprefix(State.APP_PREFIX)
] = event.actions.state_delta[key]
if key.startswith(State.USER_PREFIX):
self.user_state.setdefault(app_name, {}).setdefault(user_id, {})[
key.removeprefix(State.USER_PREFIX)
] = event.actions.state_delta[key]
storage_session = self.sessions[app_name][user_id].get(session_id)
super().append_event(session=storage_session, event=event)
storage_session.last_update_time = event.timestamp
return event
@override
def list_events(
self,
*,
app_name: str,
user_id: str,
session_id: str,
) -> ListEventsResponse:
raise NotImplementedError()

View File

@@ -0,0 +1,54 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from ..events.event import Event
class Session(BaseModel):
"""Represents a series of interactions between a user and agents.
Attributes:
id: The unique identifier of the session.
app_name: The name of the app.
user_id: The id of the user.
state: The state of the session.
events: The events of the session, e.g. user input, model response, function
call/response, etc.
last_update_time: The last update time of the session.
"""
model_config = ConfigDict(
extra='forbid',
arbitrary_types_allowed=True,
)
id: str
"""The unique identifier of the session."""
app_name: str
"""The name of the app."""
user_id: str
"""The id of the user."""
state: dict[str, Any] = Field(default_factory=dict)
"""The state of the session."""
events: list[Event] = Field(default_factory=list)
"""The events of the session, e.g. user input, model response, function
call/response, etc."""
last_update_time: float = 0.0
"""The last update time of the session."""

View File

@@ -0,0 +1,71 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
class State:
"""A state dict that maintain the current value and the pending-commit delta."""
APP_PREFIX = "app:"
USER_PREFIX = "user:"
TEMP_PREFIX = "temp:"
def __init__(self, value: dict[str, Any], delta: dict[str, Any]):
"""
Args:
value: The current value of the state dict.
delta: The delta change to the current value that hasn't been commited.
"""
self._value = value
self._delta = delta
def __getitem__(self, key: str) -> Any:
"""Returns the value of the state dict for the given key."""
if key in self._delta:
return self._delta[key]
return self._value[key]
def __setitem__(self, key: str, value: Any):
"""Sets the value of the state dict for the given key."""
# TODO: make new change only store in delta, so that self._value is only
# updated at the storage commit time.
self._value[key] = value
self._delta[key] = value
def __contains__(self, key: str) -> bool:
"""Whether the state dict contains the given key."""
return key in self._value or key in self._delta
def has_delta(self) -> bool:
"""Whether the state has pending detla."""
return bool(self._delta)
def get(self, key: str, default: Any = None) -> Any:
"""Returns the value of the state dict for the given key."""
if key not in self:
return default
return self[key]
def update(self, delta: dict[str, Any]):
"""Updates the state dict with the given delta."""
self._value.update(delta)
self._delta.update(delta)
def to_dict(self) -> dict[str, Any]:
"""Returns the state dict."""
result = {}
result.update(self._value)
result.update(self._delta)
return result

View File

@@ -0,0 +1,356 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import time
from typing import Any
from typing import Optional
from dateutil.parser import isoparse
from google import genai
from typing_extensions import override
from ..events.event import Event
from ..events.event_actions import EventActions
from .base_session_service import BaseSessionService
from .base_session_service import GetSessionConfig
from .base_session_service import ListEventsResponse
from .base_session_service import ListSessionsResponse
from .session import Session
logger = logging.getLogger(__name__)
class VertexAiSessionService(BaseSessionService):
"""Connects to the managed Vertex AI Session Service."""
def __init__(
self,
project: str = None,
location: str = None,
):
self.project = project
self.location = location
client = genai.Client(vertexai=True, project=project, location=location)
self.api_client = client._api_client
@override
def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
session_json_dict = {'user_id': user_id}
if state:
session_json_dict['session_state'] = state
api_response = self.api_client.request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
request_dict=session_json_dict,
)
logger.info(f'Create Session response {api_response}')
session_id = api_response['name'].split('/')[-3]
operation_id = api_response['name'].split('/')[-1]
max_retry_attempt = 5
while max_retry_attempt >= 0:
lro_response = self.api_client.request(
http_method='GET',
path=f'operations/{operation_id}',
request_dict={},
)
if lro_response.get('done', None):
break
time.sleep(1)
max_retry_attempt -= 1
# Get session resource
get_session_api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
)
update_timestamp = isoparse(
get_session_api_response['updateTime']
).timestamp()
session = Session(
app_name=str(app_name),
user_id=str(user_id),
id=str(session_id),
state=get_session_api_response.get('sessionState', {}),
last_update_time=update_timestamp,
)
return session
@override
def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Session:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
# Get session resource
get_session_api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
)
session_id = get_session_api_response['name'].split('/')[-1]
update_timestamp = isoparse(
get_session_api_response['updateTime']
).timestamp()
session = Session(
app_name=str(app_name),
user_id=str(user_id),
id=str(session_id),
state=get_session_api_response.get('sessionState', {}),
last_update_time=update_timestamp,
)
list_events_api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
request_dict={},
)
# Handles empty response case
if list_events_api_response.get('httpHeaders', None):
return session
session.events = [
_from_api_event(event)
for event in list_events_api_response['sessionEvents']
]
session.events = [
event for event in session.events if event.timestamp <= update_timestamp
]
session.events.sort(key=lambda event: event.timestamp)
if config:
if config.num_recent_events:
session.events = session.events[-config.num_recent_events :]
elif config.after_timestamp:
i = len(session.events) - 1
while i >= 0:
if session.events[i].timestamp < config.after_timestamp:
break
i -= 1
if i >= 0:
session.events = session.events[i:]
return session
@override
def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
request_dict={},
)
# Handles empty response case
if api_response.get('httpHeaders', None):
return ListSessionsResponse()
sessions = []
for api_session in api_response['sessions']:
session = Session(
app_name=app_name,
user_id=user_id,
id=api_session['name'].split('/')[-1],
state={},
last_update_time=isoparse(api_session['updateTime']).timestamp(),
)
sessions.append(session)
return ListSessionsResponse(sessions=sessions)
def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
self.api_client.request(
http_method='DELETE',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
)
@override
def list_events(
self,
*,
app_name: str,
user_id: str,
session_id: str,
) -> ListEventsResponse:
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
api_response = self.api_client.request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
request_dict={},
)
logger.info(f'List events response {api_response}')
# Handles empty response case
if api_response.get('httpHeaders', None):
return ListEventsResponse()
session_events = api_response['sessionEvents']
return ListEventsResponse(
events=[_from_api_event(event) for event in session_events]
)
@override
def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
super().append_event(session=session, event=event)
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
self.api_client.request(
http_method='POST',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
request_dict=_convert_event_to_json(event),
)
return event
def _convert_event_to_json(event: Event):
metadata_json = {
'partial': event.partial,
'turn_complete': event.turn_complete,
'interrupted': event.interrupted,
'branch': event.branch,
'long_running_tool_ids': (
list(event.long_running_tool_ids)
if event.long_running_tool_ids
else None
),
}
if event.grounding_metadata:
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
exclude_none=True
)
event_json = {
'author': event.author,
'invocation_id': event.invocation_id,
'timestamp': {
'seconds': int(event.timestamp),
'nanos': int(
(event.timestamp - int(event.timestamp)) * 1_000_000_000
),
},
'error_code': event.error_code,
'error_message': event.error_message,
'event_metadata': metadata_json,
}
if event.actions:
actions_json = {
'skip_summarization': event.actions.skip_summarization,
'state_delta': event.actions.state_delta,
'artifact_delta': event.actions.artifact_delta,
'transfer_agent': event.actions.transfer_to_agent,
'escalate': event.actions.escalate,
'requested_auth_configs': event.actions.requested_auth_configs,
}
event_json['actions'] = actions_json
if event.content:
event_json['content'] = event.content.model_dump(exclude_none=True)
if event.error_code:
event_json['error_code'] = event.error_code
if event.error_message:
event_json['error_message'] = event.error_message
return event_json
def _from_api_event(api_event: dict) -> Event:
event_actions = EventActions()
if api_event.get('actions', None):
event_actions = EventActions(
skip_summarization=api_event['actions'].get('skipSummarization', None),
state_delta=api_event['actions'].get('stateDelta', {}),
artifact_delta=api_event['actions'].get('artifactDelta', {}),
transfer_to_agent=api_event['actions'].get('transferAgent', None),
escalate=api_event['actions'].get('escalate', None),
requested_auth_configs=api_event['actions'].get(
'requestedAuthConfigs', {}
),
)
event = Event(
id=api_event['name'].split('/')[-1],
invocation_id=api_event['invocationId'],
author=api_event['author'],
actions=event_actions,
content=api_event.get('content', None),
timestamp=isoparse(api_event['timestamp']).timestamp(),
error_code=api_event.get('errorCode', None),
error_message=api_event.get('errorMessage', None),
)
if api_event.get('eventMetadata', None):
long_running_tool_ids_list = api_event['eventMetadata'].get(
'longRunningToolIds', None
)
event.partial = api_event['eventMetadata'].get('partial', None)
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
event.interrupted = api_event['eventMetadata'].get('interrupted', None)
event.branch = api_event['eventMetadata'].get('branch', None)
event.grounding_metadata = api_event['eventMetadata'].get(
'groundingMetadata', None
)
event.long_running_tool_ids = (
set(long_running_tool_ids_list) if long_running_tool_ids_list else None
)
return event
def _parse_reasoning_engine_id(app_name: str):
if app_name.isdigit():
return app_name
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
match = re.fullmatch(pattern, app_name)
if not bool(match):
raise ValueError(
f'App name {app_name} is not valid. It should either be the full'
' ReasoningEngine resource name, or the reasoning engine id.'
)
return match.groups()[-1]