mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-12-24 14:17:45 -06:00
No public description
PiperOrigin-RevId: 748777998
This commit is contained in:
committed by
hangfei
parent
290058eb05
commit
61d4be2d76
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import json
|
||||
@@ -20,17 +21,17 @@ from typing import Any
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import Dialect
|
||||
from sqlalchemy import ForeignKeyConstraint
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.exc import ArgumentError
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
@@ -54,6 +55,7 @@ from .base_session_service import ListSessionsResponse
|
||||
from .session import Session
|
||||
from .state import State
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -103,7 +105,7 @@ class StorageSession(Base):
|
||||
String, primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
state: Mapped[dict] = mapped_column(
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
|
||||
@@ -134,8 +136,20 @@ class StorageEvent(Base):
|
||||
author: Mapped[str] = mapped_column(String)
|
||||
branch: Mapped[str] = mapped_column(String, nullable=True)
|
||||
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
||||
content: Mapped[dict] = mapped_column(DynamicJSON)
|
||||
actions: Mapped[dict] = mapped_column(PickleType)
|
||||
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
|
||||
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
||||
|
||||
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
DynamicJSON, nullable=True
|
||||
)
|
||||
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
error_code: Mapped[str] = mapped_column(String, nullable=True)
|
||||
error_message: Mapped[str] = mapped_column(String, nullable=True)
|
||||
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
storage_session: Mapped[StorageSession] = relationship(
|
||||
"StorageSession",
|
||||
@@ -150,13 +164,28 @@ class StorageEvent(Base):
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def long_running_tool_ids(self) -> set[str]:
|
||||
return (
|
||||
set(json.loads(self.long_running_tool_ids_json))
|
||||
if self.long_running_tool_ids_json
|
||||
else set()
|
||||
)
|
||||
|
||||
@long_running_tool_ids.setter
|
||||
def long_running_tool_ids(self, value: set[str]):
|
||||
if value is None:
|
||||
self.long_running_tool_ids_json = None
|
||||
else:
|
||||
self.long_running_tool_ids_json = json.dumps(list(value))
|
||||
|
||||
|
||||
class StorageAppState(Base):
|
||||
"""Represents an app state stored in the database."""
|
||||
__tablename__ = "app_states"
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
state: Mapped[dict] = mapped_column(
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
@@ -170,7 +199,7 @@ class StorageUserState(Base):
|
||||
|
||||
app_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
state: Mapped[dict] = mapped_column(
|
||||
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
||||
MutableDict.as_mutable(DynamicJSON), default={}
|
||||
)
|
||||
update_time: Mapped[DateTime] = mapped_column(
|
||||
@@ -295,7 +324,6 @@ class DatabaseSessionService(BaseSessionService):
|
||||
last_update_time=storage_session.update_time.timestamp(),
|
||||
)
|
||||
return session
|
||||
return None
|
||||
|
||||
@override
|
||||
def get_session(
|
||||
@@ -309,7 +337,6 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 1. Get the storage session entry from session table
|
||||
# 2. Get all the events based on session id and filtering config
|
||||
# 3. Convert and return the session
|
||||
session: Session = None
|
||||
with self.DatabaseSessionFactory() as sessionFactory:
|
||||
storage_session = sessionFactory.get(
|
||||
StorageSession, (app_name, user_id, session_id)
|
||||
@@ -356,13 +383,19 @@ class DatabaseSessionService(BaseSessionService):
|
||||
author=e.author,
|
||||
branch=e.branch,
|
||||
invocation_id=e.invocation_id,
|
||||
content=e.content,
|
||||
content=_decode_content(e.content),
|
||||
actions=e.actions,
|
||||
timestamp=e.timestamp.timestamp(),
|
||||
long_running_tool_ids=e.long_running_tool_ids,
|
||||
grounding_metadata=e.grounding_metadata,
|
||||
partial=e.partial,
|
||||
turn_complete=e.turn_complete,
|
||||
error_code=e.error_code,
|
||||
error_message=e.error_message,
|
||||
interrupted=e.interrupted,
|
||||
)
|
||||
for e in storage_events
|
||||
]
|
||||
|
||||
return session
|
||||
|
||||
@override
|
||||
@@ -387,7 +420,6 @@ class DatabaseSessionService(BaseSessionService):
|
||||
)
|
||||
sessions.append(session)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
raise ValueError("Failed to retrieve sessions.")
|
||||
|
||||
@override
|
||||
def delete_session(
|
||||
@@ -406,7 +438,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
def append_event(self, session: Session, event: Event) -> Event:
|
||||
logger.info(f"Append event: {event} to session {session.id}")
|
||||
|
||||
if event.partial and not event.content:
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
# 1. Check if timestamp is stale
|
||||
@@ -455,19 +487,34 @@ class DatabaseSessionService(BaseSessionService):
|
||||
storage_user_state.state = user_state
|
||||
storage_session.state = session_state
|
||||
|
||||
encoded_content = event.content.model_dump(exclude_none=True)
|
||||
storage_event = StorageEvent(
|
||||
id=event.id,
|
||||
invocation_id=event.invocation_id,
|
||||
author=event.author,
|
||||
branch=event.branch,
|
||||
content=encoded_content,
|
||||
actions=event.actions,
|
||||
session_id=session.id,
|
||||
app_name=session.app_name,
|
||||
user_id=session.user_id,
|
||||
timestamp=datetime.fromtimestamp(event.timestamp),
|
||||
long_running_tool_ids=event.long_running_tool_ids,
|
||||
grounding_metadata=event.grounding_metadata,
|
||||
partial=event.partial,
|
||||
turn_complete=event.turn_complete,
|
||||
error_code=event.error_code,
|
||||
error_message=event.error_message,
|
||||
interrupted=event.interrupted,
|
||||
)
|
||||
if event.content:
|
||||
encoded_content = event.content.model_dump(exclude_none=True)
|
||||
# Workaround for multimodal Content throwing JSON not serializable
|
||||
# error with SQLAlchemy.
|
||||
for p in encoded_content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = (
|
||||
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
|
||||
)
|
||||
storage_event.content = encoded_content
|
||||
|
||||
sessionFactory.add(storage_event)
|
||||
|
||||
@@ -489,8 +536,7 @@ class DatabaseSessionService(BaseSessionService):
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
) -> ListEventsResponse:
|
||||
pass
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
def convert_event(event: StorageEvent) -> Event:
|
||||
"""Converts a storage event to an event."""
|
||||
@@ -505,7 +551,7 @@ def convert_event(event: StorageEvent) -> Event:
|
||||
)
|
||||
|
||||
|
||||
def _extract_state_delta(state: dict):
|
||||
def _extract_state_delta(state: dict[str, Any]):
|
||||
app_state_delta = {}
|
||||
user_state_delta = {}
|
||||
session_state_delta = {}
|
||||
@@ -528,3 +574,10 @@ def _merge_state(app_state, user_state, session_state):
|
||||
for key in user_state.keys():
|
||||
merged_state[State.USER_PREFIX + key] = user_state[key]
|
||||
return merged_state
|
||||
|
||||
|
||||
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
|
||||
for p in content["parts"]:
|
||||
if "inline_data" in p:
|
||||
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
|
||||
return content
|
||||
|
||||
Reference in New Issue
Block a user