chore: fix python format.

PiperOrigin-RevId: 759674648
This commit is contained in:
Shangjie Chen 2025-05-16 10:48:35 -07:00 committed by Copybara-Service
parent d0f117ebbc
commit 2f006264ce

View File

@ -61,10 +61,7 @@ DEFAULT_MAX_VARCHAR_LENGTH = 256
class DynamicJSON(TypeDecorator):
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
serialization for other databases.
"""
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
impl = Text # Default implementation is TEXT
@ -242,10 +239,7 @@ class DatabaseSessionService(BaseSessionService):
"""A session service that uses a database for storage."""
def __init__(self, db_url: str):
"""
Args:
db_url: The database URL to connect to.
"""
"""Initializes the database session service with a database URL."""
# 1. Create DB engine for db connection
# 2. Create all tables based on schema
# 3. Initialize all properties
@ -274,7 +268,7 @@ class DatabaseSessionService(BaseSessionService):
self.inspector = inspect(self.db_engine)
# DB session factory method
self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
sessionmaker(bind=self.db_engine)
)
@ -297,11 +291,11 @@ class DatabaseSessionService(BaseSessionService):
# 4. Build the session object with generated id
# 5. Return the session
with self.DatabaseSessionFactory() as sessionFactory:
with self.database_session_factory() as session_factory:
# Fetch app and user states from storage
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
storage_user_state = sessionFactory.get(
storage_app_state = session_factory.get(StorageAppState, (app_name))
storage_user_state = session_factory.get(
StorageUserState, (app_name, user_id)
)
@ -311,12 +305,12 @@ class DatabaseSessionService(BaseSessionService):
# Create state tables if not exist
if not storage_app_state:
storage_app_state = StorageAppState(app_name=app_name, state={})
sessionFactory.add(storage_app_state)
session_factory.add(storage_app_state)
if not storage_user_state:
storage_user_state = StorageUserState(
app_name=app_name, user_id=user_id, state={}
)
sessionFactory.add(storage_user_state)
session_factory.add(storage_user_state)
# Extract state deltas
app_state_delta, user_state_delta, session_state = _extract_state_delta(
@ -340,10 +334,10 @@ class DatabaseSessionService(BaseSessionService):
id=session_id,
state=session_state,
)
sessionFactory.add(storage_session)
sessionFactory.commit()
session_factory.add(storage_session)
session_factory.commit()
sessionFactory.refresh(storage_session)
session_factory.refresh(storage_session)
# Merge states for response
merged_state = _merge_state(app_state, user_state, session_state)
@ -368,31 +362,37 @@ class DatabaseSessionService(BaseSessionService):
# 1. Get the storage session entry from session table
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
with self.DatabaseSessionFactory() as sessionFactory:
storage_session = sessionFactory.get(
with self.database_session_factory() as session_factory:
storage_session = session_factory.get(
StorageSession, (app_name, user_id, session_id)
)
if storage_session is None:
return None
if config and config.after_timestamp:
after_dt = datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc)
after_dt = datetime.fromtimestamp(
config.after_timestamp, tz=timezone.utc
)
timestamp_filter = StorageEvent.timestamp > after_dt
else:
timestamp_filter = True
storage_events = (
sessionFactory.query(StorageEvent)
session_factory.query(StorageEvent)
.filter(StorageEvent.session_id == storage_session.id)
.filter(timestamp_filter)
.filter(timestamp_filter)
.order_by(StorageEvent.timestamp.asc())
.limit(config.num_recent_events if config and config.num_recent_events else None)
.limit(
config.num_recent_events
if config and config.num_recent_events
else None
)
.all()
)
# Fetch states from storage
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
storage_user_state = sessionFactory.get(
storage_app_state = session_factory.get(StorageAppState, (app_name))
storage_user_state = session_factory.get(
StorageUserState, (app_name, user_id)
)
@ -436,9 +436,9 @@ class DatabaseSessionService(BaseSessionService):
async def list_sessions(
self, *, app_name: str, user_id: str
) -> ListSessionsResponse:
with self.DatabaseSessionFactory() as sessionFactory:
with self.database_session_factory() as session_factory:
results = (
sessionFactory.query(StorageSession)
session_factory.query(StorageSession)
.filter(StorageSession.app_name == app_name)
.filter(StorageSession.user_id == user_id)
.all()
@ -459,14 +459,14 @@ class DatabaseSessionService(BaseSessionService):
async def delete_session(
self, app_name: str, user_id: str, session_id: str
) -> None:
with self.DatabaseSessionFactory() as sessionFactory:
with self.database_session_factory() as session_factory:
stmt = delete(StorageSession).where(
StorageSession.app_name == app_name,
StorageSession.user_id == user_id,
StorageSession.id == session_id,
)
sessionFactory.execute(stmt)
sessionFactory.commit()
session_factory.execute(stmt)
session_factory.commit()
@override
async def append_event(self, session: Session, event: Event) -> Event:
@ -478,8 +478,8 @@ class DatabaseSessionService(BaseSessionService):
# 1. Check if timestamp is stale
# 2. Update session attributes based on event config
# 3. Store event to table
with self.DatabaseSessionFactory() as sessionFactory:
storage_session = sessionFactory.get(
with self.database_session_factory() as session_factory:
storage_session = session_factory.get(
StorageSession, (session.app_name, session.user_id, session.id)
)
@ -493,10 +493,10 @@ class DatabaseSessionService(BaseSessionService):
)
# Fetch states from storage
storage_app_state = sessionFactory.get(
storage_app_state = session_factory.get(
StorageAppState, (session.app_name)
)
storage_user_state = sessionFactory.get(
storage_user_state = session_factory.get(
StorageUserState, (session.app_name, session.user_id)
)
@ -545,10 +545,10 @@ class DatabaseSessionService(BaseSessionService):
if event.content:
storage_event.content = _session_util.encode_content(event.content)
sessionFactory.add(storage_event)
session_factory.add(storage_event)
sessionFactory.commit()
sessionFactory.refresh(storage_session)
session_factory.commit()
session_factory.refresh(storage_session)
# Update timestamp with commit time
session.last_update_time = storage_session.update_time.timestamp()