mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-24 14:17:45 -06:00
Agent Development Kit(ADK)
An easy-to-use and powerful framework to build AI agents.
This commit is contained in:
41
src/google/adk/sessions/__init__.py
Normal file
41
src/google/adk/sessions/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from .base_session_service import BaseSessionService
|
||||
from .in_memory_session_service import InMemorySessionService
|
||||
from .session import Session
|
||||
from .state import State
|
||||
from .vertex_ai_session_service import VertexAiSessionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'BaseSessionService',
|
||||
'InMemorySessionService',
|
||||
'Session',
|
||||
'State',
|
||||
'VertexAiSessionService',
|
||||
]
|
||||
|
||||
try:
|
||||
from .database_session_service import DatabaseSessionService
|
||||
|
||||
__all__.append('DatabaseSessionService')
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
'DatabaseSessionService require sqlalchemy>=2.0, please ensure it is'
|
||||
' installed correctly.'
|
||||
)
|
||||
133
src/google/adk/sessions/base_session_service.py
Normal file
133
src/google/adk/sessions/base_session_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from ..events.event import Event
|
||||
from .session import Session
|
||||
from .state import State
|
||||
|
||||
|
||||
class GetSessionConfig(BaseModel):
|
||||
"""The configuration of getting a session."""
|
||||
num_recent_events: Optional[int] = None
|
||||
after_timestamp: Optional[float] = None
|
||||
|
||||
|
||||
class ListSessionsResponse(BaseModel):
|
||||
"""The response of listing sessions.
|
||||
|
||||
The events and states are not set within each Session object.
|
||||
"""
|
||||
sessions: list[Session] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ListEventsResponse(BaseModel):
|
||||
"""The response of listing events in a session."""
|
||||
events: list[Event] = Field(default_factory=list)
|
||||
next_page_token: Optional[str] = None
|
||||
|
||||
|
||||
class BaseSessionService(abc.ABC):
|
||||
"""Base class for session services.
|
||||
|
||||
The service provides a set of methods for managing sessions and events.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Session:
|
||||
"""Creates a new session.
|
||||
|
||||
Args:
|
||||
app_name: the name of the app.
|
||||
user_id: the id of the user.
|
||||
state: the initial state of the session.
|
||||
session_id: the client-provided id of the session. If not provided, a
|
||||
generated ID will be used.
|
||||
|
||||
Returns:
|
||||
session: The newly created session instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Optional[Session]:
|
||||
"""Gets a session."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
"""Lists all the sessions."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
"""Deletes a session."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def list_events(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
) -> ListEventsResponse:
|
||||
"""Lists events in a session."""
|
||||
pass
|
||||
|
||||
def close_session(self, *, session: Session):
|
||||
"""Closes a session."""
|
||||
# TODO: determine whether we want to finalize the session here.
|
||||
pass
|
||||
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
"""Appends an event to a session object."""
|
||||
if event.partial:
|
||||
return event
|
||||
self.__update_session_state(session, event)
|
||||
session.events.append(event)
|
||||
return event
|
||||
|
||||
def __update_session_state(self, session: Session, event: Event):
|
||||
"""Updates the session state based on the event."""
|
||||
if not event.actions or not event.actions.state_delta:
|
||||
return
|
||||
for key, value in event.actions.state_delta.items():
|
||||
if key.startswith(State.TEMP_PREFIX):
|
||||
continue
|
||||
session.state.update({key: value})
|
||||
522
src/google/adk/sessions/database_session_service.py
Normal file
522
src/google/adk/sessions/database_session_service.py
Normal file
@@ -0,0 +1,522 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import Dialect
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import Session as DatabaseSessionFactory
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.schema import MetaData
|
||||
from sqlalchemy.types import DateTime
|
||||
from sqlalchemy.types import PickleType
|
||||
from sqlalchemy.types import String
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from typing_extensions import override
|
||||
from tzlocal import get_localzone
|
||||
|
||||
from ..events.event import Event
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
from .base_session_service import ListEventsResponse
|
||||
from .base_session_service import ListSessionsResponse
|
||||
from .session import Session
|
||||
from .state import State
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamicJSON(TypeDecorator):
|
||||
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
||||
|
||||
serialization for other databases.
|
||||
"""
|
||||
|
||||
impl = Text # Default implementation is TEXT
|
||||
|
||||
def load_dialect_impl(self, dialect: Dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(postgresql.JSONB)
|
||||
else:
|
||||
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
||||
|
||||
def process_bind_param(self, value, dialect: Dialect):
|
||||
if value is not None:
|
||||
if dialect.name == "postgresql":
|
||||
return value # JSONB handles dict directly
|
||||
else:
|
||||
return json.dumps(value) # Serialize to JSON string for TEXT
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect: Dialect):
|
||||
if value is not None:
|
||||
if dialect.name == "postgresql":
|
||||
return value # JSONB returns dict directly
|
||||
else:
|
||||
return json.loads(value) # Deserialize from JSON string for TEXT
|
||||
return value
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for database tables."""
|
||||
pass
|
||||
|
||||
|
||||
class StorageSession(Base):
|
||||
"""Represents a session stored in the database."""
|
||||
__tablename__ = "sessions"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
state: Mapped[dict] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
|
||||
create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
DateTime(), default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
storage_events: Mapped[list["StorageEvent"]] = relationship(
|
||||
"StorageEvent",
|
||||
back_populates="storage_session",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
|
||||
|
||||
|
||||
class StorageEvent(Base):
|
||||
"""Represents an event stored in the database."""
|
||||
__tablename__ = "events"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
session_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
invocation_id: Mapped[str] = mapped_column(String)
|
||||
author: Mapped[str] = mapped_column(String)
|
||||
branch: Mapped[str] = mapped_column(String, nullable=True)
|
||||
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||
content: Mapped[dict] = mapped_column(DynamicJSON)
|
||||
actions: Mapped[dict] = mapped_column(PickleType)
|
||||
|
||||
storage_session: Mapped[StorageSession] = relationship(
|
||||
"StorageSession",
|
||||
back_populates="storage_events",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
ForeignKeyConstraint(
|
||||
["app_name", "user_id", "session_id"],
|
||||
["sessions.app_name", "sessions.user_id", "sessions.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class StorageAppState(Base):
|
||||
"""Represents an app state stored in the database."""
|
||||
__tablename__ = "app_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
state: Mapped[dict] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
DateTime(), default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class StorageUserState(Base):
|
||||
"""Represents a user state stored in the database."""
|
||||
__tablename__ = "user_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
state: Mapped[dict] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
DateTime(), default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class DatabaseSessionService(BaseSessionService):
|
||||
"""A session service that uses a database for storage."""
|
||||
|
||||
def __init__(self, db_url: str):
|
||||
"""
|
||||
Args:
|
||||
db_url: The database URL to connect to.
|
||||
"""
|
||||
# 1. Create DB engine for db connection
|
||||
# 2. Create all tables based on schema
|
||||
# 3. Initialize all properies
|
||||
|
||||
supported_dialects = ["postgresql", "mysql", "sqlite"]
|
||||
dialect = db_url.split("://")[0]
|
||||
|
||||
if dialect in supported_dialects:
|
||||
db_engine = create_engine(db_url)
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL: {db_url}")
|
||||
|
||||
# Get the local timezone
|
||||
local_timezone = get_localzone()
|
||||
logger.info(f"Local timezone: {local_timezone}")
|
||||
|
||||
self.db_engine: Engine = db_engine
|
||||
self.metadata: MetaData = MetaData()
|
||||
self.inspector = inspect(self.db_engine)
|
||||
|
||||
# DB session factory method
|
||||
self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
|
||||
sessionmaker(bind=self.db_engine)
|
||||
)
|
||||
|
||||
# Uncomment to recreate DB every time
|
||||
# Base.metadata.drop_all(self.db_engine)
|
||||
Base.metadata.create_all(self.db_engine)
|
||||
|
||||
@override
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Session:
|
||||
# 1. Populate states.
|
||||
# 2. Build storage session object
|
||||
# 3. Add the object to the table
|
||||
# 4. Build the session object with generated id
|
||||
# 5. Return the session
|
||||
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
|
||||
# Fetch app and user states from storage
|
||||
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
|
||||
storage_user_state = sessionFactory.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
|
||||
# Create state tables if not exist
|
||||
if not storage_app_state:
|
||||
storage_app_state = StorageAppState(app_name=app_name, state={})
|
||||
sessionFactory.add(storage_app_state)
|
||||
if not storage_user_state:
|
||||
storage_user_state = StorageUserState(
|
||||
app_name=app_name, user_id=user_id, state={}
|
||||
)
|
||||
sessionFactory.add(storage_user_state)
|
||||
|
||||
# Extract state deltas
|
||||
app_state_delta, user_state_delta, session_state = _extract_state_delta(
|
||||
state
|
||||
)
|
||||
|
||||
# Apply state delta
|
||||
app_state.update(app_state_delta)
|
||||
user_state.update(user_state_delta)
|
||||
|
||||
# Store app and user state
|
||||
if app_state_delta:
|
||||
storage_app_state.state = app_state
|
||||
if user_state_delta:
|
||||
storage_user_state.state = user_state
|
||||
|
||||
# Store the session
|
||||
storage_session = StorageSession(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=session_state,
|
||||
)
|
||||
sessionFactory.add(storage_session)
|
||||
sessionFactory.commit()
|
||||
|
||||
sessionFactory.refresh(storage_session)
|
||||
|
||||
# Merge states for response
|
||||
merged_state = _merge_state(app_state, user_state, session_state)
|
||||
session = Session(
|
||||
app_name=str(storage_session.app_name),
|
||||
user_id=str(storage_session.user_id),
|
||||
id=str(storage_session.id),
|
||||
state=merged_state,
|
||||
last_update_time=storage_session.update_time.timestamp(),
|
||||
)
|
||||
return session
|
||||
return None
|
||||
|
||||
@override
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Optional[Session]:
|
||||
# 1. Get the storage session entry from session table
|
||||
# 2. Get all the events based on session id and filtering config
|
||||
# 3. Convert and return the session
|
||||
session: Session = None
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
storage_session = sessionFactory.get(
|
||||
StorageSession, (app_name, user_id, session_id)
|
||||
)
|
||||
if storage_session is None:
|
||||
return None
|
||||
|
||||
storage_events = (
|
||||
sessionFactory.query(StorageEvent)
|
||||
.filter(StorageEvent.session_id == storage_session.id)
|
||||
.filter(
|
||||
StorageEvent.timestamp < config.after_timestamp
|
||||
if config
|
||||
else True
|
||||
)
|
||||
.limit(config.num_recent_events if config else None)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Fetch states from storage
|
||||
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
|
||||
storage_user_state = sessionFactory.get(
|
||||
StorageUserState, (app_name, user_id)
|
||||
)
|
||||
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
session_state = storage_session.state
|
||||
|
||||
# Merge states
|
||||
merged_state = _merge_state(app_state, user_state, session_state)
|
||||
|
||||
# Convert storage session to session
|
||||
session = Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=merged_state,
|
||||
last_update_time=storage_session.update_time.timestamp(),
|
||||
)
|
||||
session.events = [
|
||||
Event(
|
||||
id=e.id,
|
||||
author=e.author,
|
||||
branch=e.branch,
|
||||
invocation_id=e.invocation_id,
|
||||
content=e.content,
|
||||
actions=e.actions,
|
||||
timestamp=e.timestamp.timestamp(),
|
||||
)
|
||||
for e in storage_events
|
||||
]
|
||||
|
||||
return session
|
||||
|
||||
@override
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
results = (
|
||||
sessionFactory.query(StorageSession)
|
||||
.filter(StorageSession.app_name == app_name)
|
||||
.filter(StorageSession.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
sessions = []
|
||||
for storage_session in results:
|
||||
session = Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=storage_session.id,
|
||||
state={},
|
||||
last_update_time=storage_session.update_time.timestamp(),
|
||||
)
|
||||
sessions.append(session)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
raise ValueError("Failed to retrieve sessions.")
|
||||
|
||||
@override
|
||||
def delete_session(
|
||||
self, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
stmt = delete(StorageSession).where(
|
||||
StorageSession.app_name == app_name,
|
||||
StorageSession.user_id == user_id,
|
||||
StorageSession.id == session_id,
|
||||
)
|
||||
sessionFactory.execute(stmt)
|
||||
sessionFactory.commit()
|
||||
|
||||
@override
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
logger.info(f"Append event: {event} to session {session.id}")
|
||||
|
||||
if event.partial and not event.content:
|
||||
return event
|
||||
|
||||
# 1. Check if timestamp is stale
|
||||
# 2. Update session attributes based on event config
|
||||
# 3. Store event to table
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
storage_session = sessionFactory.get(
|
||||
StorageSession, (session.app_name, session.user_id, session.id)
|
||||
)
|
||||
|
||||
if storage_session.update_time.timestamp() > session.last_update_time:
|
||||
raise ValueError(
|
||||
f"Session last_update_time {session.last_update_time} is later than"
|
||||
f" the upate_time in storage {storage_session.update_time}"
|
||||
)
|
||||
|
||||
# Fetch states from storage
|
||||
storage_app_state = sessionFactory.get(
|
||||
StorageAppState, (session.app_name)
|
||||
)
|
||||
storage_user_state = sessionFactory.get(
|
||||
StorageUserState, (session.app_name, session.user_id)
|
||||
)
|
||||
|
||||
app_state = storage_app_state.state if storage_app_state else {}
|
||||
user_state = storage_user_state.state if storage_user_state else {}
|
||||
session_state = storage_session.state
|
||||
|
||||
# Extract state delta
|
||||
app_state_delta = {}
|
||||
user_state_delta = {}
|
||||
session_state_delta = {}
|
||||
if event.actions:
|
||||
if event.actions.state_delta:
|
||||
app_state_delta, user_state_delta, session_state_delta = (
|
||||
_extract_state_delta(event.actions.state_delta)
|
||||
)
|
||||
|
||||
# Merge state
|
||||
app_state.update(app_state_delta)
|
||||
user_state.update(user_state_delta)
|
||||
session_state.update(session_state_delta)
|
||||
|
||||
# Update storage
|
||||
storage_app_state.state = app_state
|
||||
storage_user_state.state = user_state
|
||||
storage_session.state = session_state
|
||||
|
||||
encoded_content = event.content.model_dump(exclude_none=True)
|
||||
storage_event = StorageEvent(
|
||||
id=event.id,
|
||||
invocation_id=event.invocation_id,
|
||||
author=event.author,
|
||||
branch=event.branch,
|
||||
content=encoded_content,
|
||||
actions=event.actions,
|
||||
session_id=session.id,
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||
)
|
||||
|
||||
sessionFactory.add(storage_event)
|
||||
|
||||
sessionFactory.commit()
|
||||
sessionFactory.refresh(storage_session)
|
||||
|
||||
# Update timestamp with commit time
|
||||
session.last_update_time = storage_session.update_time.timestamp()
|
||||
|
||||
# Also update the in-memory session
|
||||
super().append_event(session=session, event=event)
|
||||
return event
|
||||
|
||||
@override
|
||||
def list_events(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
) -> ListEventsResponse:
|
||||
pass
|
||||
|
||||
|
||||
def convert_event(event: StorageEvent) -> Event:
|
||||
"""Converts a storage event to an event."""
|
||||
return Event(
|
||||
id=event.id,
|
||||
author=event.author,
|
||||
branch=event.branch,
|
||||
invocation_id=event.invocation_id,
|
||||
content=event.content,
|
||||
actions=event.actions,
|
||||
timestamp=event.timestamp.timestamp(),
|
||||
)
|
||||
|
||||
|
||||
def _extract_state_delta(state: dict):
|
||||
app_state_delta = {}
|
||||
user_state_delta = {}
|
||||
session_state_delta = {}
|
||||
if state:
|
||||
for key in state.keys():
|
||||
if key.startswith(State.APP_PREFIX):
|
||||
app_state_delta[key.removeprefix(State.APP_PREFIX)] = state[key]
|
||||
elif key.startswith(State.USER_PREFIX):
|
||||
user_state_delta[key.removeprefix(State.USER_PREFIX)] = state[key]
|
||||
elif not key.startswith(State.TEMP_PREFIX):
|
||||
session_state_delta[key] = state[key]
|
||||
return app_state_delta, user_state_delta, session_state_delta
|
||||
|
||||
|
||||
def _merge_state(app_state, user_state, session_state):
|
||||
# Merge states for response
|
||||
merged_state = copy.deepcopy(session_state)
|
||||
for key in app_state.keys():
|
||||
merged_state[State.APP_PREFIX + key] = app_state[key]
|
||||
for key in user_state.keys():
|
||||
merged_state[State.USER_PREFIX + key] = user_state[key]
|
||||
return merged_state
|
||||
206
src/google/adk/sessions/in_memory_session_service.py
Normal file
206
src/google/adk/sessions/in_memory_session_service.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
from .base_session_service import ListEventsResponse
|
||||
from .base_session_service import ListSessionsResponse
|
||||
from .session import Session
|
||||
from .state import State
|
||||
|
||||
|
||||
class InMemorySessionService(BaseSessionService):
|
||||
"""An in-memory implementation of the session service."""
|
||||
|
||||
def __init__(self):
|
||||
# A map from app name to a map from user ID to a map from session ID to session.
|
||||
self.sessions: dict[str, dict[str, dict[str, Session]]] = {}
|
||||
# A map from app name to a map from user ID to a map from key to the value.
|
||||
self.user_state: dict[str, dict[str, dict[str, Any]]] = {}
|
||||
# A map from app name to a map from key to the value.
|
||||
self.app_state: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@override
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Session:
|
||||
session_id = (
|
||||
session_id.strip()
|
||||
if session_id and session_id.strip()
|
||||
else str(uuid.uuid4())
|
||||
)
|
||||
session = Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=state or {},
|
||||
last_update_time=time.time(),
|
||||
)
|
||||
|
||||
if app_name not in self.sessions:
|
||||
self.sessions[app_name] = {}
|
||||
if user_id not in self.sessions[app_name]:
|
||||
self.sessions[app_name][user_id] = {}
|
||||
self.sessions[app_name][user_id][session_id] = session
|
||||
|
||||
copied_session = copy.deepcopy(session)
|
||||
return self._merge_state(app_name, user_id, copied_session)
|
||||
|
||||
@override
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Session:
|
||||
if app_name not in self.sessions:
|
||||
return None
|
||||
if user_id not in self.sessions[app_name]:
|
||||
return None
|
||||
if session_id not in self.sessions[app_name][user_id]:
|
||||
return None
|
||||
|
||||
session = self.sessions[app_name][user_id].get(session_id)
|
||||
copied_session = copy.deepcopy(session)
|
||||
|
||||
if config:
|
||||
if config.num_recent_events:
|
||||
copied_session.events = copied_session.events[
|
||||
-config.num_recent_events :
|
||||
]
|
||||
elif config.after_timestamp:
|
||||
i = len(session.events) - 1
|
||||
while i >= 0:
|
||||
if copied_session.events[i].timestamp < config.after_timestamp:
|
||||
break
|
||||
i -= 1
|
||||
if i >= 0:
|
||||
copied_session.events = copied_session.events[i:]
|
||||
|
||||
return self._merge_state(app_name, user_id, copied_session)
|
||||
|
||||
def _merge_state(self, app_name: str, user_id: str, copied_session: Session):
|
||||
# Merge app state
|
||||
if app_name in self.app_state:
|
||||
for key in self.app_state[app_name].keys():
|
||||
copied_session.state[State.APP_PREFIX + key] = self.app_state[app_name][
|
||||
key
|
||||
]
|
||||
|
||||
if (
|
||||
app_name not in self.user_state
|
||||
or user_id not in self.user_state[app_name]
|
||||
):
|
||||
return copied_session
|
||||
|
||||
# Merge session state with user state.
|
||||
for key in self.user_state[app_name][user_id].keys():
|
||||
copied_session.state[State.USER_PREFIX + key] = self.user_state[app_name][
|
||||
user_id
|
||||
][key]
|
||||
return copied_session
|
||||
|
||||
@override
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
empty_response = ListSessionsResponse()
|
||||
if app_name not in self.sessions:
|
||||
return empty_response
|
||||
if user_id not in self.sessions[app_name]:
|
||||
return empty_response
|
||||
|
||||
sessions_without_events = []
|
||||
for session in self.sessions[app_name][user_id].values():
|
||||
copied_session = copy.deepcopy(session)
|
||||
copied_session.events = []
|
||||
copied_session.state = {}
|
||||
sessions_without_events.append(copied_session)
|
||||
return ListSessionsResponse(sessions=sessions_without_events)
|
||||
|
||||
@override
|
||||
def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
if (
|
||||
self.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
is None
|
||||
):
|
||||
return None
|
||||
|
||||
self.sessions[app_name][user_id].pop(session_id)
|
||||
|
||||
@override
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
# Update the in-memory session.
|
||||
super().append_event(session=session, event=event)
|
||||
session.last_update_time = event.timestamp
|
||||
|
||||
# Update the storage session
|
||||
app_name = session.app_name
|
||||
user_id = session.user_id
|
||||
session_id = session.id
|
||||
if app_name not in self.sessions:
|
||||
return event
|
||||
if user_id not in self.sessions[app_name]:
|
||||
return event
|
||||
if session_id not in self.sessions[app_name][user_id]:
|
||||
return event
|
||||
|
||||
if event.actions and event.actions.state_delta:
|
||||
for key in event.actions.state_delta:
|
||||
if key.startswith(State.APP_PREFIX):
|
||||
self.app_state.setdefault(app_name, {})[
|
||||
key.removeprefix(State.APP_PREFIX)
|
||||
] = event.actions.state_delta[key]
|
||||
|
||||
if key.startswith(State.USER_PREFIX):
|
||||
self.user_state.setdefault(app_name, {}).setdefault(user_id, {})[
|
||||
key.removeprefix(State.USER_PREFIX)
|
||||
] = event.actions.state_delta[key]
|
||||
|
||||
storage_session = self.sessions[app_name][user_id].get(session_id)
|
||||
super().append_event(session=storage_session, event=event)
|
||||
|
||||
storage_session.last_update_time = event.timestamp
|
||||
|
||||
return event
|
||||
|
||||
@override
|
||||
def list_events(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
) -> ListEventsResponse:
|
||||
raise NotImplementedError()
|
||||
54
src/google/adk/sessions/session.py
Normal file
54
src/google/adk/sessions/session.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..events.event import Event
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""Represents a series of interactions between a user and agents.
|
||||
|
||||
Attributes:
|
||||
id: The unique identifier of the session.
|
||||
app_name: The name of the app.
|
||||
user_id: The id of the user.
|
||||
state: The state of the session.
|
||||
events: The events of the session, e.g. user input, model response, function
|
||||
call/response, etc.
|
||||
last_update_time: The last update time of the session.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra='forbid',
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
|
||||
id: str
|
||||
"""The unique identifier of the session."""
|
||||
app_name: str
|
||||
"""The name of the app."""
|
||||
user_id: str
|
||||
"""The id of the user."""
|
||||
state: dict[str, Any] = Field(default_factory=dict)
|
||||
"""The state of the session."""
|
||||
events: list[Event] = Field(default_factory=list)
|
||||
"""The events of the session, e.g. user input, model response, function
|
||||
call/response, etc."""
|
||||
last_update_time: float = 0.0
|
||||
"""The last update time of the session."""
|
||||
71
src/google/adk/sessions/state.py
Normal file
71
src/google/adk/sessions/state.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class State:
|
||||
"""A state dict that maintain the current value and the pending-commit delta."""
|
||||
|
||||
APP_PREFIX = "app:"
|
||||
USER_PREFIX = "user:"
|
||||
TEMP_PREFIX = "temp:"
|
||||
|
||||
def __init__(self, value: dict[str, Any], delta: dict[str, Any]):
|
||||
"""
|
||||
Args:
|
||||
value: The current value of the state dict.
|
||||
delta: The delta change to the current value that hasn't been commited.
|
||||
"""
|
||||
self._value = value
|
||||
self._delta = delta
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Returns the value of the state dict for the given key."""
|
||||
if key in self._delta:
|
||||
return self._delta[key]
|
||||
return self._value[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Any):
|
||||
"""Sets the value of the state dict for the given key."""
|
||||
# TODO: make new change only store in delta, so that self._value is only
|
||||
# updated at the storage commit time.
|
||||
self._value[key] = value
|
||||
self._delta[key] = value
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Whether the state dict contains the given key."""
|
||||
return key in self._value or key in self._delta
|
||||
|
||||
def has_delta(self) -> bool:
|
||||
"""Whether the state has pending detla."""
|
||||
return bool(self._delta)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Returns the value of the state dict for the given key."""
|
||||
if key not in self:
|
||||
return default
|
||||
return self[key]
|
||||
|
||||
def update(self, delta: dict[str, Any]):
|
||||
"""Updates the state dict with the given delta."""
|
||||
self._value.update(delta)
|
||||
self._delta.update(delta)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Returns the state dict."""
|
||||
result = {}
|
||||
result.update(self._value)
|
||||
result.update(self._delta)
|
||||
return result
|
||||
356
src/google/adk/sessions/vertex_ai_session_service.py
Normal file
356
src/google/adk/sessions/vertex_ai_session_service.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from google import genai
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from ..events.event_actions import EventActions
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
from .base_session_service import ListEventsResponse
|
||||
from .base_session_service import ListSessionsResponse
|
||||
from .session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VertexAiSessionService(BaseSessionService):
|
||||
"""Connects to the managed Vertex AI Session Service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project: str = None,
|
||||
location: str = None,
|
||||
):
|
||||
self.project = project
|
||||
self.location = location
|
||||
|
||||
client = genai.Client(vertexai=True, project=project, location=location)
|
||||
self.api_client = client._api_client
|
||||
|
||||
@override
|
||||
def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Session:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
|
||||
session_json_dict = {'user_id': user_id}
|
||||
if state:
|
||||
session_json_dict['session_state'] = state
|
||||
|
||||
api_response = self.api_client.request(
|
||||
http_method='POST',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions',
|
||||
request_dict=session_json_dict,
|
||||
)
|
||||
logger.info(f'Create Session response {api_response}')
|
||||
|
||||
session_id = api_response['name'].split('/')[-3]
|
||||
operation_id = api_response['name'].split('/')[-1]
|
||||
|
||||
max_retry_attempt = 5
|
||||
while max_retry_attempt >= 0:
|
||||
lro_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'operations/{operation_id}',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
if lro_response.get('done', None):
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
max_retry_attempt -= 1
|
||||
|
||||
# Get session resource
|
||||
get_session_api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
update_timestamp = isoparse(
|
||||
get_session_api_response['updateTime']
|
||||
).timestamp()
|
||||
session = Session(
|
||||
app_name=str(app_name),
|
||||
user_id=str(user_id),
|
||||
id=str(session_id),
|
||||
state=get_session_api_response.get('sessionState', {}),
|
||||
last_update_time=update_timestamp,
|
||||
)
|
||||
return session
|
||||
|
||||
@override
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Session:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
|
||||
# Get session resource
|
||||
get_session_api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
session_id = get_session_api_response['name'].split('/')[-1]
|
||||
update_timestamp = isoparse(
|
||||
get_session_api_response['updateTime']
|
||||
).timestamp()
|
||||
session = Session(
|
||||
app_name=str(app_name),
|
||||
user_id=str(user_id),
|
||||
id=str(session_id),
|
||||
state=get_session_api_response.get('sessionState', {}),
|
||||
last_update_time=update_timestamp,
|
||||
)
|
||||
|
||||
list_events_api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
# Handles empty response case
|
||||
if list_events_api_response.get('httpHeaders', None):
|
||||
return session
|
||||
|
||||
session.events = [
|
||||
_from_api_event(event)
|
||||
for event in list_events_api_response['sessionEvents']
|
||||
]
|
||||
session.events = [
|
||||
event for event in session.events if event.timestamp <= update_timestamp
|
||||
]
|
||||
session.events.sort(key=lambda event: event.timestamp)
|
||||
|
||||
if config:
|
||||
if config.num_recent_events:
|
||||
session.events = session.events[-config.num_recent_events :]
|
||||
elif config.after_timestamp:
|
||||
i = len(session.events) - 1
|
||||
while i >= 0:
|
||||
if session.events[i].timestamp < config.after_timestamp:
|
||||
break
|
||||
i -= 1
|
||||
if i >= 0:
|
||||
session.events = session.events[i:]
|
||||
|
||||
return session
|
||||
|
||||
@override
|
||||
def list_sessions(
|
||||
self, *, app_name: str, user_id: str
|
||||
) -> ListSessionsResponse:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
|
||||
api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions?filter=user_id={user_id}',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
# Handles empty response case
|
||||
if api_response.get('httpHeaders', None):
|
||||
return ListSessionsResponse()
|
||||
|
||||
sessions = []
|
||||
for api_session in api_response['sessions']:
|
||||
session = Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=api_session['name'].split('/')[-1],
|
||||
state={},
|
||||
last_update_time=isoparse(api_session['updateTime']).timestamp(),
|
||||
)
|
||||
sessions.append(session)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
self.api_client.request(
|
||||
http_method='DELETE',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
@override
|
||||
def list_events(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
) -> ListEventsResponse:
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(app_name)
|
||||
api_response = self.api_client.request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events',
|
||||
request_dict={},
|
||||
)
|
||||
|
||||
logger.info(f'List events response {api_response}')
|
||||
|
||||
# Handles empty response case
|
||||
if api_response.get('httpHeaders', None):
|
||||
return ListEventsResponse()
|
||||
|
||||
session_events = api_response['sessionEvents']
|
||||
|
||||
return ListEventsResponse(
|
||||
events=[_from_api_event(event) for event in session_events]
|
||||
)
|
||||
|
||||
@override
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
# Update the in-memory session.
|
||||
super().append_event(session=session, event=event)
|
||||
|
||||
reasoning_engine_id = _parse_reasoning_engine_id(session.app_name)
|
||||
self.api_client.request(
|
||||
http_method='POST',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}:appendEvent',
|
||||
request_dict=_convert_event_to_json(event),
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def _convert_event_to_json(event: Event):
|
||||
metadata_json = {
|
||||
'partial': event.partial,
|
||||
'turn_complete': event.turn_complete,
|
||||
'interrupted': event.interrupted,
|
||||
'branch': event.branch,
|
||||
'long_running_tool_ids': (
|
||||
list(event.long_running_tool_ids)
|
||||
if event.long_running_tool_ids
|
||||
else None
|
||||
),
|
||||
}
|
||||
if event.grounding_metadata:
|
||||
metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
|
||||
event_json = {
|
||||
'author': event.author,
|
||||
'invocation_id': event.invocation_id,
|
||||
'timestamp': {
|
||||
'seconds': int(event.timestamp),
|
||||
'nanos': int(
|
||||
(event.timestamp - int(event.timestamp)) * 1_000_000_000
|
||||
),
|
||||
},
|
||||
'error_code': event.error_code,
|
||||
'error_message': event.error_message,
|
||||
'event_metadata': metadata_json,
|
||||
}
|
||||
|
||||
if event.actions:
|
||||
actions_json = {
|
||||
'skip_summarization': event.actions.skip_summarization,
|
||||
'state_delta': event.actions.state_delta,
|
||||
'artifact_delta': event.actions.artifact_delta,
|
||||
'transfer_agent': event.actions.transfer_to_agent,
|
||||
'escalate': event.actions.escalate,
|
||||
'requested_auth_configs': event.actions.requested_auth_configs,
|
||||
}
|
||||
event_json['actions'] = actions_json
|
||||
if event.content:
|
||||
event_json['content'] = event.content.model_dump(exclude_none=True)
|
||||
if event.error_code:
|
||||
event_json['error_code'] = event.error_code
|
||||
if event.error_message:
|
||||
event_json['error_message'] = event.error_message
|
||||
return event_json
|
||||
|
||||
|
||||
def _from_api_event(api_event: dict) -> Event:
|
||||
event_actions = EventActions()
|
||||
if api_event.get('actions', None):
|
||||
event_actions = EventActions(
|
||||
skip_summarization=api_event['actions'].get('skipSummarization', None),
|
||||
state_delta=api_event['actions'].get('stateDelta', {}),
|
||||
artifact_delta=api_event['actions'].get('artifactDelta', {}),
|
||||
transfer_to_agent=api_event['actions'].get('transferAgent', None),
|
||||
escalate=api_event['actions'].get('escalate', None),
|
||||
requested_auth_configs=api_event['actions'].get(
|
||||
'requestedAuthConfigs', {}
|
||||
),
|
||||
)
|
||||
|
||||
event = Event(
|
||||
id=api_event['name'].split('/')[-1],
|
||||
invocation_id=api_event['invocationId'],
|
||||
author=api_event['author'],
|
||||
actions=event_actions,
|
||||
content=api_event.get('content', None),
|
||||
timestamp=isoparse(api_event['timestamp']).timestamp(),
|
||||
error_code=api_event.get('errorCode', None),
|
||||
error_message=api_event.get('errorMessage', None),
|
||||
)
|
||||
|
||||
if api_event.get('eventMetadata', None):
|
||||
long_running_tool_ids_list = api_event['eventMetadata'].get(
|
||||
'longRunningToolIds', None
|
||||
)
|
||||
event.partial = api_event['eventMetadata'].get('partial', None)
|
||||
event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
|
||||
event.interrupted = api_event['eventMetadata'].get('interrupted', None)
|
||||
event.branch = api_event['eventMetadata'].get('branch', None)
|
||||
event.grounding_metadata = api_event['eventMetadata'].get(
|
||||
'groundingMetadata', None
|
||||
)
|
||||
event.long_running_tool_ids = (
|
||||
set(long_running_tool_ids_list) if long_running_tool_ids_list else None
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def _parse_reasoning_engine_id(app_name: str):
|
||||
if app_name.isdigit():
|
||||
return app_name
|
||||
|
||||
pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$'
|
||||
match = re.fullmatch(pattern, app_name)
|
||||
|
||||
if not bool(match):
|
||||
raise ValueError(
|
||||
f'App name {app_name} is not valid. It should either be the full'
|
||||
' ReasoningEngine resource name, or the reasoning engine id.'
|
||||
)
|
||||
|
||||
return match.groups()[-1]
|
||||
Reference in New Issue
Block a user