feat(a2a): add file support and multimodal content processing for A2A protocol

parent 958eeec4a6
commit 6bf0ea52e0
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Add Task Agent for structured single-task execution
- Improve context management in agent execution
- Add file support for A2A protocol (Agent-to-Agent) endpoints
- Implement multimodal content processing in A2A messages

## [0.0.9] - 2025-05-13
@@ -31,6 +31,7 @@
Routes for the A2A (Agent-to-Agent) protocol.

This module implements the standard A2A routes according to the specification.
Supports both text messages and file uploads through the message parts mechanism.
"""

import uuid
@@ -92,7 +93,39 @@ async def process_a2a_request(
    db: Session = Depends(get_db),
    a2a_service: A2AService = Depends(get_a2a_service),
):
    """Processes an A2A request."""
    """
    Processes an A2A request.

    Supports both text messages and file uploads. For file uploads,
    include file parts in the message following the A2A protocol format:

    {
        "jsonrpc": "2.0",
        "id": "request-id",
        "method": "tasks/send",
        "params": {
            "id": "task-id",
            "sessionId": "session-id",
            "message": {
                "role": "user",
                "parts": [
                    {
                        "type": "text",
                        "text": "Analyze this image"
                    },
                    {
                        "type": "file",
                        "file": {
                            "name": "example.jpg",
                            "mimeType": "image/jpeg",
                            "bytes": "base64-encoded-content"
                        }
                    }
                ]
            }
        }
    }
    """
    # Verify the API key
    if not verify_api_key(db, x_api_key):
        raise HTTPException(status_code=401, detail="Invalid API key")
@@ -100,10 +133,60 @@ async def process_a2a_request(
    # Process the request
    try:
        request_body = await request.json()

        debug_request_body = {}
        if "method" in request_body:
            debug_request_body["method"] = request_body["method"]
        if "id" in request_body:
            debug_request_body["id"] = request_body["id"]

        logger.info(f"A2A request received: {debug_request_body}")

        # Log if request contains file parts for debugging
        if isinstance(request_body, dict) and "params" in request_body:
            params = request_body.get("params", {})
            message = params.get("message", {})
            parts = message.get("parts", [])

            logger.info(f"A2A message contains {len(parts)} parts")
            for i, part in enumerate(parts):
                if not isinstance(part, dict):
                    logger.warning(f"Part {i+1} is not a dictionary: {type(part)}")
                    continue

                part_type = part.get("type")
                logger.info(f"Part {i+1} type: {part_type}")

                if part_type == "file":
                    file_info = part.get("file", {})
                    logger.info(
                        f"File part found: {file_info.get('name')} ({file_info.get('mimeType')})"
                    )
                    if "bytes" in file_info:
                        bytes_data = file_info.get("bytes", "")
                        bytes_size = len(bytes_data) * 0.75
                        logger.info(f"File size: ~{bytes_size/1024:.2f} KB")
                        if bytes_data:
                            sample = (
                                bytes_data[:10] + "..."
                                if len(bytes_data) > 10
                                else bytes_data
                            )
                            logger.info(f"Sample of base64 data: {sample}")
                elif part_type == "text":
                    text_content = part.get("text", "")
                    preview = (
                        text_content[:30] + "..."
                        if len(text_content) > 30
                        else text_content
                    )
                    logger.info(f"Text part found: '{preview}'")

        result = await a2a_service.process_request(agent_id, request_body)

        # If the response is a streaming response, return as EventSourceResponse
        if hasattr(result, "__aiter__"):
            logger.info("Returning streaming response")

            async def event_generator():
                async for item in result:
@@ -115,11 +198,15 @@ async def process_a2a_request(
            return EventSourceResponse(event_generator())

        # Otherwise, return as JSONResponse
        logger.info("Returning standard JSON response")
        if hasattr(result, "model_dump"):
            return JSONResponse(result.model_dump(exclude_none=True))
        return JSONResponse(result)
    except Exception as e:
        logger.error(f"Error processing A2A request: {e}")
        import traceback

        logger.error(f"Full traceback: {traceback.format_exc()}")
        return JSONResponse(
            status_code=500,
            content={
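For orientation, a minimal client-side sketch of such a request follows. The route path and the `x-api-key` header name are assumptions inferred from the handler signature above (only `agent_id` and `x_api_key` appear in this diff), so treat both as hypothetical.

```python
import base64

import requests  # any HTTP client works; requests is used here for brevity

AGENT_ID = "agent-123"  # hypothetical agent ID
URL = f"http://localhost:8000/api/v1/a2a/{AGENT_ID}"  # assumed path, not shown in this diff

fake_jpeg = b"\xff\xd8\xff\xe0" + b"\x00" * 16  # stand-in JPEG bytes for illustration

payload = {
    "jsonrpc": "2.0",
    "id": "request-1",
    "method": "tasks/send",
    "params": {
        "id": "task-1",
        "sessionId": "user123_agent-123",  # "{external_id}_{agent_id}" format used elsewhere in this commit
        "message": {
            "role": "user",
            "parts": [
                {"type": "text", "text": "Analyze this image"},
                {
                    "type": "file",
                    "file": {
                        "name": "example.jpg",
                        "mimeType": "image/jpeg",
                        "bytes": base64.b64encode(fake_jpeg).decode("utf-8"),
                    },
                },
            ],
        },
    },
}

response = requests.post(URL, json=payload, headers={"x-api-key": "your-api-key"})
print(response.json())
```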
@@ -28,6 +28,7 @@
"""

import uuid
import base64
from fastapi import (
    APIRouter,
    Depends,
@@ -47,7 +48,7 @@ from src.core.jwt_middleware import (
from src.services import (
    agent_service,
)
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
from src.services.agent_runner import run_agent, run_agent_stream
from src.core.exceptions import AgentNotFoundError
from src.services.service_providers import (
@@ -59,7 +60,7 @@ from src.services.service_providers import (
from datetime import datetime
import logging
import json
from typing import Optional, Dict
from typing import Optional, Dict, List, Any

logger = logging.getLogger(__name__)

@@ -195,6 +196,29 @@ async def websocket_chat(
                if not message:
                    continue

                files = None
                if data.get("files") and isinstance(data.get("files"), list):
                    try:
                        files = []
                        for file_data in data.get("files"):
                            if (
                                isinstance(file_data, dict)
                                and file_data.get("filename")
                                and file_data.get("content_type")
                                and file_data.get("data")
                            ):
                                files.append(
                                    FileData(
                                        filename=file_data.get("filename"),
                                        content_type=file_data.get("content_type"),
                                        data=file_data.get("data"),
                                    )
                                )
                        logger.info(f"Processed {len(files)} files via WebSocket")
                    except Exception as e:
                        logger.error(f"Error processing files: {str(e)}")
                        files = None

                async for chunk in run_agent_stream(
                    agent_id=agent_id,
                    external_id=external_id,
@@ -203,6 +227,7 @@ async def websocket_chat(
                    artifacts_service=artifacts_service,
                    memory_service=memory_service,
                    db=db,
                    files=files,
                ):
                    await websocket.send_json(
                        {"message": json.loads(chunk), "turn_complete": False}
@@ -259,6 +284,7 @@ async def chat(
        artifacts_service,
        memory_service,
        db,
        files=request.files,
    )

    return {
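On the client side, the WebSocket handler above expects each incoming JSON frame to carry the text under `message` and any attachments under `files`, each with the three keys it validates (`filename`, `content_type`, `data`). A hedged sketch of such a frame:

```python
import base64
import json

fake_pdf = b"%PDF-1.7 stand-in report body"  # illustrative bytes, not a real PDF

frame = {
    "message": "Summarize the attached report",
    "files": [
        {
            "filename": "report.pdf",
            "content_type": "application/pdf",
            "data": base64.b64encode(fake_pdf).decode("utf-8"),
        }
    ],
}

# With any WebSocket client library:
# await websocket.send(json.dumps(frame))
print(json.dumps(frame)[:80])
```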
@@ -30,8 +30,9 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from src.config.database import get_db
from typing import List
from typing import List, Optional, Dict, Any
import uuid
import base64
from src.core.jwt_middleware import (
    get_jwt_token,
    verify_user_client,
@@ -48,7 +49,7 @@ from src.services.session_service import (
    get_sessions_by_agent,
    get_sessions_by_client,
)
from src.services.service_providers import session_service
from src.services.service_providers import session_service, artifacts_service
import logging

logger = logging.getLogger(__name__)
@@ -118,13 +119,18 @@ async def get_session(

@router.get(
    "/{session_id}/messages",
    response_model=List[Event],
)
async def get_agent_messages(
    session_id: str,
    db: Session = Depends(get_db),
    payload: dict = Depends(get_jwt_token),
):
    """
    Gets messages from a session with embedded artifacts.

    This function loads all messages from a session and processes any references
    to artifacts, loading them and converting them to base64 for direct use in the frontend.
    """
    # Get the session
    session = get_session_by_id(session_service, session_id)
    if not session:
@@ -139,7 +145,160 @@ async def get_agent_messages(
    if agent:
        await verify_user_client(payload, db, agent.client_id)

    return get_session_events(session_service, session_id)
    # Parse the session ID to obtain app_name and user_id
    parts = session_id.split("_")
    if len(parts) != 2:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid session ID format"
        )

    user_id, app_name = parts[0], parts[1]

    events = get_session_events(session_service, session_id)

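The two-segment split assumed here matches the `f"{external_id}_{agent_id}"` session IDs built by `run_agent` later in this commit. A quick illustration of the parsing and its main edge case:

```python
# Session IDs elsewhere in this commit are assembled as f"{external_id}_{agent_id}".
session_id = "user123_agent456"  # hypothetical values
parts = session_id.split("_")
assert len(parts) == 2
user_id, app_name = parts[0], parts[1]
assert (user_id, app_name) == ("user123", "agent456")

# An external_id that itself contains "_" breaks the format and would
# trigger the 400 "Invalid session ID format" response above.
assert len("user_123_agent456".split("_")) != 2
```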
    processed_events = []
    for event in events:
        event_dict = event.dict()

        def process_dict(d):
            if isinstance(d, dict):
                for key, value in list(d.items()):
                    if isinstance(value, bytes):
                        try:
                            d[key] = base64.b64encode(value).decode("utf-8")
                            logger.debug(f"Converted bytes field to base64: {key}")
                        except Exception as e:
                            logger.error(f"Error encoding bytes to base64: {str(e)}")
                            d[key] = None
                    elif isinstance(value, dict):
                        process_dict(value)
                    elif isinstance(value, list):
                        for item in value:
                            if isinstance(item, (dict, list)):
                                process_dict(item)
            elif isinstance(d, list):
                for i, item in enumerate(d):
                    if isinstance(item, bytes):
                        try:
                            d[i] = base64.b64encode(item).decode("utf-8")
                        except Exception as e:
                            logger.error(
                                f"Error encoding bytes to base64 in list: {str(e)}"
                            )
                            d[i] = None
                    elif isinstance(item, (dict, list)):
                        process_dict(item)
            return d

        # Process the entire event dictionary
        event_dict = process_dict(event_dict)

        # Process the content parts specifically
        if event_dict.get("content") and event_dict["content"].get("parts"):
            for part in event_dict["content"]["parts"]:
                # Process inlineData if present
                if part and part.get("inlineData") and part["inlineData"].get("data"):
                    # Check if it's already a string or if it's bytes
                    if isinstance(part["inlineData"]["data"], bytes):
                        # Convert bytes to base64 string
                        part["inlineData"]["data"] = base64.b64encode(
                            part["inlineData"]["data"]
                        ).decode("utf-8")
                        logger.debug(
                            f"Converted binary data to base64 in message {event_dict.get('id')}"
                        )

                # Process fileData if present (reference to an artifact)
                if part and part.get("fileData") and part["fileData"].get("fileId"):
                    try:
                        # Extract the file name from the fileId
                        file_id = part["fileData"]["fileId"]

                        # Load the artifact from the artifacts service
                        artifact = artifacts_service.load_artifact(
                            app_name=app_name,
                            user_id=user_id,
                            session_id=session_id,
                            filename=file_id,
                        )

                        if artifact and hasattr(artifact, "inline_data"):
                            # Extract the data and MIME type
                            file_bytes = artifact.inline_data.data
                            mime_type = artifact.inline_data.mime_type

                            # Add inlineData with the artifact data
                            if not part.get("inlineData"):
                                part["inlineData"] = {}

                            # Ensure we're sending a base64 string, not bytes
                            if isinstance(file_bytes, bytes):
                                try:
                                    part["inlineData"]["data"] = base64.b64encode(
                                        file_bytes
                                    ).decode("utf-8")
                                except Exception as e:
                                    logger.error(
                                        f"Error encoding artifact to base64: {str(e)}"
                                    )
                                    part["inlineData"]["data"] = None
                            else:
                                part["inlineData"]["data"] = str(file_bytes)

                            part["inlineData"]["mimeType"] = mime_type

                            logger.debug(
                                f"Loaded artifact {file_id} for message {event_dict.get('id')}"
                            )
                    except Exception as e:
                        logger.error(f"Error loading artifact: {str(e)}")
                        # Don't interrupt the flow if an artifact fails

        # Check artifact_delta in actions
        if event_dict.get("actions") and event_dict["actions"].get("artifact_delta"):
            artifact_deltas = event_dict["actions"]["artifact_delta"]
            for filename, version in artifact_deltas.items():
                try:
                    # Load the artifact
                    artifact = artifacts_service.load_artifact(
                        app_name=app_name,
                        user_id=user_id,
                        session_id=session_id,
                        filename=filename,
                        version=version,
                    )

                    if artifact and hasattr(artifact, "inline_data"):
                        # If the event doesn't have an artifacts section, create it
                        if "artifacts" not in event_dict:
                            event_dict["artifacts"] = {}

                        # Add the artifact to the event's artifacts list
                        file_bytes = artifact.inline_data.data
                        mime_type = artifact.inline_data.mime_type

                        # Ensure the bytes are converted to base64
                        event_dict["artifacts"][filename] = {
                            "data": (
                                base64.b64encode(file_bytes).decode("utf-8")
                                if isinstance(file_bytes, bytes)
                                else str(file_bytes)
                            ),
                            "mimeType": mime_type,
                            "version": version,
                        }

                        logger.debug(
                            f"Added artifact {filename} (v{version}) to message {event_dict.get('id')}"
                        )
                except Exception as e:
                    logger.error(
                        f"Error processing artifact_delta {filename}: {str(e)}"
                    )

        processed_events.append(event_dict)

    return processed_events


@router.delete(
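The recursive sweep above turns every `bytes` value in the event tree into a base64 string so the result is JSON-serializable. A standalone sketch of the same idea (note that `process_dict` skips raw bytes sitting directly inside a list that is a dict value; the sketch below also covers that case):

```python
import base64


def encode_bytes_in_place(node):
    """Recursively replace bytes values with base64 strings, in place."""
    if isinstance(node, dict):
        for key, value in list(node.items()):
            if isinstance(value, bytes):
                node[key] = base64.b64encode(value).decode("utf-8")
            elif isinstance(value, (dict, list)):
                encode_bytes_in_place(value)
    elif isinstance(node, list):
        for i, item in enumerate(node):
            if isinstance(item, bytes):
                node[i] = base64.b64encode(item).decode("utf-8")
            elif isinstance(item, (dict, list)):
                encode_bytes_in_place(item)
    return node


event = {"content": {"parts": [{"inlineData": {"data": b"\x89PNG"}}]}}
encode_bytes_in_place(event)
assert event["content"]["parts"][0]["inlineData"]["data"] == "iVBORw=="
```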
@@ -27,36 +27,42 @@
└──────────────────────────────────────────────────────────────────────────────┘
"""

from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Optional, Any
from datetime import datetime


class FileData(BaseModel):
    """Model to represent file data sent in a chat request."""

    filename: str = Field(..., description="File name")
    content_type: str = Field(..., description="File content type")
    data: str = Field(..., description="File content encoded in base64")


class ChatRequest(BaseModel):
    """Schema for chat requests"""
    """Model to represent a chat request."""

    agent_id: str = Field(
        ..., description="ID of the agent that will process the message"
    agent_id: str = Field(..., description="Agent ID to process the message")
    external_id: str = Field(..., description="External ID for user identification")
    message: str = Field(..., description="User message to the agent")
    files: Optional[List[FileData]] = Field(
        None, description="List of files attached to the message"
    )
    external_id: str = Field(
        ..., description="ID of the external_id that will process the message"
    )
    message: str = Field(..., description="User message")


class ChatResponse(BaseModel):
    """Schema for chat responses"""
    """Model to represent a chat response."""

    response: str = Field(..., description="Agent response")
    status: str = Field(..., description="Operation status")
    error: Optional[str] = Field(None, description="Error message, if there is one")
    timestamp: str = Field(..., description="Timestamp of the response")
    response: str = Field(..., description="Response generated by the agent")
    message_history: List[Dict[str, Any]] = Field(
        default_factory=list, description="Message history"
    )
    status: str = Field(..., description="Response status (success/error)")
    timestamp: str = Field(..., description="Response timestamp")


class ErrorResponse(BaseModel):
    """Schema for error responses"""
    """Model to represent an error response."""

    error: str = Field(..., description="Error message")
    status_code: int = Field(..., description="HTTP status code of the error")
    details: Optional[Dict[str, Any]] = Field(
        None, description="Additional error details"
    )
    detail: str = Field(..., description="Error details")
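With the new schema in place, an attachment-bearing request is constructed from the exact fields declared above:

```python
import base64

from src.schemas.chat import ChatRequest, FileData

fake_jpeg = b"\xff\xd8\xff\xe0" + b"\x00" * 16  # stand-in JPEG bytes

request = ChatRequest(
    agent_id="agent-123",      # hypothetical IDs for illustration
    external_id="user123",
    message="Analyze this image",
    files=[
        FileData(
            filename="example.jpg",
            content_type="image/jpeg",
            data=base64.b64encode(fake_jpeg).decode("utf-8"),
        )
    ],
)
```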
@@ -33,8 +33,11 @@ from collections.abc import AsyncIterable
from typing import Dict, Optional
from uuid import UUID
import json
import base64
import uuid as uuid_pkg

from sqlalchemy.orm import Session
from google.genai.types import Part, Blob

from src.config.settings import settings
from src.services.agent_service import (
@@ -76,6 +79,7 @@ from src.schemas.a2a_types import (
    AgentAuthentication,
    AgentProvider,
)
from src.schemas.chat import FileData

logger = logging.getLogger(__name__)

@@ -281,12 +285,29 @@ class A2ATaskManager:
        all_messages.append(agent_message)

        task_state = self._determine_task_state(result)
        artifact = Artifact(parts=agent_message.parts, index=0)

        # Create artifacts for any file content
        artifacts = []
        # First, add the main response as an artifact
        artifacts.append(Artifact(parts=agent_message.parts, index=0))

        # Also add any files from the message history
        for idx, msg in enumerate(all_messages, 1):
            for part in msg.parts:
                if hasattr(part, "type") and part.type == "file":
                    artifacts.append(
                        Artifact(
                            parts=[part],
                            index=idx,
                            name=part.file.name,
                            description=f"File from message {idx}",
                        )
                    )

        task = await self.update_store(
            task_params.id,
            TaskStatus(state=task_state, message=agent_message),
            [artifact],
            artifacts,
        )

        await self._update_task_history(
@@ -400,6 +421,32 @@ class A2ATaskManager:

            final_message = None

            # Check for files in the user message and include them as artifacts
            user_files = []
            for part in request.params.message.parts:
                if (
                    hasattr(part, "type")
                    and part.type == "file"
                    and hasattr(part, "file")
                ):
                    user_files.append(
                        Artifact(
                            parts=[part],
                            index=0,
                            name=part.file.name if part.file.name else "file",
                            description="File from user",
                        )
                    )

            # Send artifacts for any user files
            for artifact in user_files:
                yield SendTaskStreamingResponse(
                    id=request.id,
                    result=TaskArtifactUpdateEvent(
                        id=request.params.id, artifact=artifact
                    ),
                )

            async for chunk in run_agent_stream(
                agent_id=str(agent.id),
                external_id=external_id,
@@ -418,7 +465,48 @@ class A2ATaskManager:
                        parts = content.get("parts", [])

                        if parts:
                            update_message = Message(role=role, parts=parts)
                            # Modify to handle file parts as well
                            agent_parts = []
                            for part in parts:
                                # Handle different part types
                                if part.get("type") == "text":
                                    agent_parts.append(part)
                                    full_response += part.get("text", "")
                                elif part.get("inlineData") and part["inlineData"].get(
                                    "data"
                                ):
                                    # Convert inline data to file part
                                    mime_type = part["inlineData"].get(
                                        "mimeType", "application/octet-stream"
                                    )
                                    file_name = f"file_{uuid_pkg.uuid4().hex}{self._get_extension_from_mime(mime_type)}"
                                    file_part = {
                                        "type": "file",
                                        "file": {
                                            "name": file_name,
                                            "mimeType": mime_type,
                                            "bytes": part["inlineData"]["data"],
                                        },
                                    }
                                    agent_parts.append(file_part)

                                    # Also send as artifact
                                    yield SendTaskStreamingResponse(
                                        id=request.id,
                                        result=TaskArtifactUpdateEvent(
                                            id=request.params.id,
                                            artifact=Artifact(
                                                parts=[file_part],
                                                index=0,
                                                name=file_name,
                                                description=f"Generated {mime_type} file",
                                            ),
                                        ),
                                    )

                            if agent_parts:
                                update_message = Message(role=role, parts=agent_parts)
                                final_message = update_message

                            yield SendTaskStreamingResponse(
                                id=request.id,
@@ -431,11 +519,6 @@ class A2ATaskManager:
                                    final=False,
                                ),
                            )

                        for part in parts:
                            if part.get("type") == "text":
                                full_response += part.get("text", "")
                        final_message = update_message
                    except Exception as e:
                        logger.error(f"Error processing chunk: {e}, chunk: {chunk}")
                        continue
@@ -485,6 +568,29 @@ class A2ATaskManager:
            error=InternalError(message=f"Error streaming task process: {str(e)}"),
        )

    def _get_extension_from_mime(self, mime_type: str) -> str:
        """Get a file extension from MIME type."""
        if not mime_type:
            return ""

        mime_map = {
            "image/jpeg": ".jpg",
            "image/png": ".png",
            "image/gif": ".gif",
            "application/pdf": ".pdf",
            "text/plain": ".txt",
            "text/html": ".html",
            "text/csv": ".csv",
            "application/json": ".json",
            "application/xml": ".xml",
            "application/msword": ".doc",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
            "application/vnd.ms-excel": ".xls",
            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        }

        return mime_map.get(mime_type, "")

    async def update_store(
        self,
        task_id: str,
@@ -514,19 +620,193 @@ class A2ATaskManager:
        return task

    def _extract_user_query(self, task_params: TaskSendParams) -> str:
        """Extracts the user query from the task parameters."""
        """Extracts the user query from the task parameters and processes any files."""
        if not task_params.message or not task_params.message.parts:
            raise ValueError("Message or parts are missing in task parameters")

        part = task_params.message.parts[0]
        if part.type != "text":
            raise ValueError("Only text parts are supported")
        # Process file parts first
        text_parts = []
        has_files = False
        file_parts = []

        return part.text
        logger.info(
            f"Extracting query from message with {len(task_params.message.parts)} parts"
        )

        # Extract text parts and file parts separately
        for idx, part in enumerate(task_params.message.parts):
            logger.info(
                f"Processing part {idx+1}, type: {getattr(part, 'type', 'unknown')}"
            )
            if hasattr(part, "type"):
                if part.type == "text":
                    logger.info(f"Found text part: '{part.text[:50]}...' (truncated)")
                    text_parts.append(part.text)
                elif part.type == "file":
                    logger.info(
                        f"Found file part: {getattr(getattr(part, 'file', None), 'name', 'unnamed')}"
                    )
                    has_files = True
                    try:
                        processed_file = self._process_file_part(
                            part, task_params.sessionId
                        )
                        if processed_file:
                            file_parts.append(processed_file)
                    except Exception as e:
                        logger.error(f"Error processing file part: {e}")
                        # Continue with other parts even if a file fails
                else:
                    logger.warning(f"Unknown part type: {part.type}")
            else:
                logger.warning(f"Part has no type attribute: {part}")

        # Store the file parts in self for later use
        self._last_processed_files = file_parts if file_parts else None

        # If we have at least one text part, use that as the query
        if text_parts:
            final_query = " ".join(text_parts)
            logger.info(
                f"Final query from text parts: '{final_query[:50]}...' (truncated)"
            )
            return final_query
        # If we only have file parts, create a generic query asking for analysis
        elif has_files:
            logger.info("No text parts, using generic query for file analysis")
            return "Analyze the attached files"
        else:
            logger.error("No supported content parts found in the message")
            raise ValueError("No supported content parts found in the message")

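The rewritten `_extract_user_query` reduces to three cases: join all text parts into one query, fall back to a generic prompt when only files arrive, or raise when nothing usable is present. A toy restatement of that decision logic, using `SimpleNamespace` stand-ins for the real part objects:

```python
from types import SimpleNamespace as NS

# Hypothetical stand-ins mirroring the shape of A2A message parts:
parts = [
    NS(type="text", text="Analyze this image"),
    NS(type="file", file=NS(name="example.jpg")),
]

text_parts = [p.text for p in parts if getattr(p, "type", None) == "text"]
has_files = any(getattr(p, "type", None) == "file" for p in parts)

if text_parts:
    query = " ".join(text_parts)          # case 1: text wins
elif has_files:
    query = "Analyze the attached files"  # case 2: files only
else:
    raise ValueError("No supported content parts found in the message")

assert query == "Analyze this image"
```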
    def _process_file_part(self, part, session_id: str):
        """Processes a file part and saves it to the artifact service.

        Returns:
            dict: Processed file information to pass to agent_runner
        """
        if not hasattr(part, "file") or not part.file:
            logger.warning("File part missing file data")
            return None

        file_data = part.file

        if not file_data.name:
            file_data.name = f"file_{uuid_pkg.uuid4().hex}"

        logger.info(f"Processing file {file_data.name} for session {session_id}")

        if file_data.bytes:
            # Process file data provided as base64 string
            try:
                # Convert base64 to bytes
                logger.info(f"Decoding base64 content for file {file_data.name}")
                file_bytes = base64.b64decode(file_data.bytes)

                # Determine MIME type based on binary content
                mime_type = (
                    file_data.mimeType if hasattr(file_data, "mimeType") else None
                )

                if not mime_type or mime_type == "application/octet-stream":
                    # Detection by byte signature
                    if file_bytes.startswith(b"\xff\xd8\xff"):  # JPEG signature
                        mime_type = "image/jpeg"
                    elif file_bytes.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG signature
                        mime_type = "image/png"
                    elif file_bytes.startswith(b"GIF87a") or file_bytes.startswith(
                        b"GIF89a"
                    ):  # GIF
                        mime_type = "image/gif"
                    elif file_bytes.startswith(b"%PDF"):  # PDF
                        mime_type = "application/pdf"
                    else:
                        # Fallback to avoid generic type in images
                        if file_data.name.lower().endswith((".jpg", ".jpeg")):
                            mime_type = "image/jpeg"
                        elif file_data.name.lower().endswith(".png"):
                            mime_type = "image/png"
                        elif file_data.name.lower().endswith(".gif"):
                            mime_type = "image/gif"
                        elif file_data.name.lower().endswith(".pdf"):
                            mime_type = "application/pdf"
                        else:
                            mime_type = "application/octet-stream"

                logger.info(
                    f"Decoded file size: {len(file_bytes)} bytes, MIME type: {mime_type}"
                )

                # Split session_id to get app_name and user_id
                parts = session_id.split("_")
                if len(parts) != 2:
                    user_id = session_id
                    app_name = "a2a"
                else:
                    user_id, app_name = parts

                # Create artifact Part
                logger.info(f"Creating artifact Part for file {file_data.name}")
                artifact = Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))

                # Save to artifact service
                logger.info(
                    f"Saving artifact {file_data.name} to {app_name}/{user_id}/{session_id}"
                )
                version = artifacts_service.save_artifact(
                    app_name=app_name,
                    user_id=user_id,
                    session_id=session_id,
                    filename=file_data.name,
                    artifact=artifact,
                )

                logger.info(
                    f"Successfully saved file {file_data.name} (version {version}) for session {session_id}"
                )

                # Import the FileData model from the chat schema
                from src.schemas.chat import FileData

                # Create a FileData object instead of a dictionary
                # This is compatible with what agent_runner.py expects
                return FileData(
                    filename=file_data.name,
                    content_type=mime_type,
                    data=file_data.bytes,  # Keep the original base64 format
                )

            except Exception as e:
                logger.error(f"Error processing file data: {str(e)}")
                # Log more details about the error to help with debugging
                import traceback

                logger.error(f"Full traceback: {traceback.format_exc()}")
                raise

        elif file_data.uri:
            # Handling URIs would require additional implementation
            # For now, log that we received a URI but can't process it
            logger.warning(f"File URI references not yet implemented: {file_data.uri}")
            # Future enhancement: fetch the file from the URI and save it
            return None

        return None

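The signature check above is a standard magic-bytes sniff; the same test isolated into a helper, runnable against synthetic inputs:

```python
SIGNATURES = [
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
    (b"%PDF", "application/pdf"),
]


def sniff_mime(file_bytes: bytes, fallback: str = "application/octet-stream") -> str:
    """Return a MIME type based on well-known leading byte signatures."""
    for magic, mime in SIGNATURES:
        if file_bytes.startswith(magic):
            return mime
    return fallback


# A PDF always begins with "%PDF", whatever the upload claims:
assert sniff_mime(b"%PDF-1.7 ...") == "application/pdf"
assert sniff_mime(b"\x00\x00") == "application/octet-stream"
```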
    async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict:
        """Executes the agent to process the user query."""
        try:
            files = getattr(self, "_last_processed_files", None)

            if files:
                logger.info(f"Passing {len(files)} files to run_agent")
                for file_info in files:
                    logger.info(
                        f"File being sent: {file_info.filename} ({file_info.content_type})"
                    )
            else:
                logger.info("No files to pass to run_agent")

            # We call the same function used in the chat API
            return await run_agent(
                agent_id=str(agent.id),
@@ -536,6 +816,7 @@ class A2ATaskManager:
                artifacts_service=artifacts_service,
                memory_service=memory_service,
                db=self.db,
                files=files,
            )
        except Exception as e:
            logger.error(f"Error running agent: {e}")
@@ -28,7 +28,7 @@
"""

from google.adk.runners import Runner
from google.genai.types import Content, Part
from google.genai.types import Content, Part, Blob
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
@@ -42,6 +42,7 @@ import asyncio
import json
from src.utils.otel import get_tracer
from opentelemetry import trace
import base64

logger = setup_logger(__name__)

@@ -56,6 +57,7 @@ async def run_agent(
    db: Session,
    session_id: Optional[str] = None,
    timeout: float = 60.0,
    files: Optional[list] = None,
):
    tracer = get_tracer()
    with tracer.start_as_current_span(
@@ -65,6 +67,7 @@ async def run_agent(
            "external_id": external_id,
            "session_id": session_id or f"{external_id}_{agent_id}",
            "message": message,
            "has_files": files is not None and len(files) > 0,
        },
    ):
        exit_stack = None
@@ -74,6 +77,9 @@ async def run_agent(
        )
        logger.info(f"Received message: {message}")

        if files and len(files) > 0:
            logger.info(f"Received {len(files)} files with message")

        get_root_agent = get_agent(db, agent_id)
        logger.info(
            f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@@ -113,7 +119,63 @@ async def run_agent(
            session_id=adk_session_id,
        )

        content = Content(role="user", parts=[Part(text=message)])
        file_parts = []
        if files and len(files) > 0:
            for file_data in files:
                try:
                    file_bytes = base64.b64decode(file_data.data)

                    logger.info(f"DEBUG - Processing file: {file_data.filename}")
                    logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
                    logger.info(f"DEBUG - MIME type: '{file_data.content_type}'")
                    logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")

                    try:
                        file_part = Part(
                            inline_data=Blob(
                                mime_type=file_data.content_type, data=file_bytes
                            )
                        )
                        logger.info("DEBUG - Part created successfully")
                    except Exception as part_error:
                        logger.error(
                            f"DEBUG - Error creating Part: {str(part_error)}"
                        )
                        logger.error(
                            f"DEBUG - Error type: {type(part_error).__name__}"
                        )
                        import traceback

                        logger.error(
                            f"DEBUG - Stack trace: {traceback.format_exc()}"
                        )
                        raise

                    # Save the file in the ArtifactService
                    version = artifacts_service.save_artifact(
                        app_name=agent_id,
                        user_id=external_id,
                        session_id=adk_session_id,
                        filename=file_data.filename,
                        artifact=file_part,
                    )
                    logger.info(
                        f"Saved file {file_data.filename} as version {version}"
                    )

                    # Add the Part to the list of parts for the message content
                    file_parts.append(file_part)
                except Exception as e:
                    logger.error(
                        f"Error processing file {file_data.filename}: {str(e)}"
                    )

        # Create the content with the text message and the files
        parts = [Part(text=message)]
        if file_parts:
            parts.extend(file_parts)

        content = Content(role="user", parts=parts)
        logger.info("Starting agent execution")

        final_response_text = "No final response captured."
@@ -256,6 +318,7 @@ async def run_agent_stream(
    memory_service: InMemoryMemoryService,
    db: Session,
    session_id: Optional[str] = None,
    files: Optional[list] = None,
) -> AsyncGenerator[str, None]:
    tracer = get_tracer()
    span = tracer.start_span(
@@ -265,6 +328,7 @@ async def run_agent_stream(
            "external_id": external_id,
            "session_id": session_id or f"{external_id}_{agent_id}",
            "message": message,
            "has_files": files is not None and len(files) > 0,
        },
    )
    try:
@@ -275,6 +339,9 @@ async def run_agent_stream(
        )
        logger.info(f"Received message: {message}")

        if files and len(files) > 0:
            logger.info(f"Received {len(files)} files with message")

        get_root_agent = get_agent(db, agent_id)
        logger.info(
            f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@@ -314,7 +381,72 @@ async def run_agent_stream(
            session_id=adk_session_id,
        )

        content = Content(role="user", parts=[Part(text=message)])
        # Process the received files
        file_parts = []
        if files and len(files) > 0:
            for file_data in files:
                try:
                    # Decode the base64 file
                    file_bytes = base64.b64decode(file_data.data)

                    # Detailed debug
                    logger.info(
                        f"DEBUG - Processing file: {file_data.filename}"
                    )
                    logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
                    logger.info(
                        f"DEBUG - MIME type: '{file_data.content_type}'"
                    )
                    logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")

                    # Create a Part for the file using the default constructor
                    try:
                        file_part = Part(
                            inline_data=Blob(
                                mime_type=file_data.content_type,
                                data=file_bytes,
                            )
                        )
                        logger.info("DEBUG - Part created successfully")
                    except Exception as part_error:
                        logger.error(
                            f"DEBUG - Error creating Part: {str(part_error)}"
                        )
                        logger.error(
                            f"DEBUG - Error type: {type(part_error).__name__}"
                        )
                        import traceback

                        logger.error(
                            f"DEBUG - Stack trace: {traceback.format_exc()}"
                        )
                        raise

                    # Save the file in the ArtifactService
                    version = artifacts_service.save_artifact(
                        app_name=agent_id,
                        user_id=external_id,
                        session_id=adk_session_id,
                        filename=file_data.filename,
                        artifact=file_part,
                    )
                    logger.info(
                        f"Saved file {file_data.filename} as version {version}"
                    )

                    # Add the Part to the list of parts for the message content
                    file_parts.append(file_part)
                except Exception as e:
                    logger.error(
                        f"Error processing file {file_data.filename}: {str(e)}"
                    )

        # Create the content with the text message and the files
        parts = [Part(text=message)]
        if file_parts:
            parts.extend(file_parts)

        content = Content(role="user", parts=parts)
        logger.info("Starting agent streaming execution")

        try:
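Both runners now assemble the user turn the same way: a `Content` whose first part is the text and whose remaining parts are inline file blobs. Stripped of logging and artifact persistence, the assembly is:

```python
import base64

from google.genai.types import Blob, Content, Part

message = "Analyze this image"
file_b64 = base64.b64encode(b"\xff\xd8\xff\xe0" + b"\x00" * 16).decode("utf-8")  # stand-in upload

parts = [Part(text=message)]
parts.append(
    Part(inline_data=Blob(mime_type="image/jpeg", data=base64.b64decode(file_b64)))
)

content = Content(role="user", parts=parts)  # the multimodal turn handed to the ADK Runner
```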
@@ -27,10 +27,16 @@
└──────────────────────────────────────────────────────────────────────────────┘
"""

import base64
import uuid
from typing import Dict, List, Any, Optional
from google.genai.types import Part, Blob

from src.schemas.a2a_types import (
    ContentTypeNotSupportedError,
    JSONRPCResponse,
    UnsupportedOperationError,
    Message,
)


@@ -55,3 +61,130 @@ def new_incompatible_types_error(request_id):

def new_not_implemented_error(request_id):
    return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())


def extract_files_from_message(message: Message) -> List[Dict[str, Any]]:
    """
    Extract file parts from an A2A message.

    Args:
        message: An A2A Message object

    Returns:
        List of file parts extracted from the message
    """
    if not message or not message.parts:
        return []

    files = []
    for part in message.parts:
        if hasattr(part, "type") and part.type == "file" and hasattr(part, "file"):
            files.append(part)

    return files


def a2a_part_to_adk_part(a2a_part: Dict[str, Any]) -> Optional[Part]:
    """
    Convert an A2A protocol part to an ADK Part object.

    Args:
        a2a_part: An A2A part dictionary

    Returns:
        Converted ADK Part object or None if conversion not possible
    """
    part_type = a2a_part.get("type")
    if part_type == "file" and "file" in a2a_part:
        file_data = a2a_part["file"]
        if "bytes" in file_data:
            try:
                # Convert base64 to bytes
                file_bytes = base64.b64decode(file_data["bytes"])
                mime_type = file_data.get("mimeType", "application/octet-stream")

                # Create ADK Part
                return Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))
            except Exception:
                return None
    elif part_type == "text" and "text" in a2a_part:
        # For text parts, we could create a text blob if needed
        return None

    return None


def adk_part_to_a2a_part(
    adk_part: Part, filename: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Convert an ADK Part object to an A2A protocol part.

    Args:
        adk_part: An ADK Part object
        filename: Optional filename to use

    Returns:
        Converted A2A Part dictionary or None if conversion not possible
    """
    if hasattr(adk_part, "inline_data") and adk_part.inline_data:
        if adk_part.inline_data.data and adk_part.inline_data.mime_type:
            # Convert binary data to base64
            file_bytes = adk_part.inline_data.data
            mime_type = adk_part.inline_data.mime_type

            # Generate filename if not provided
            if not filename:
                ext = get_extension_from_mime(mime_type)
                filename = f"file_{uuid.uuid4().hex}{ext}"

            # Convert to A2A FilePart dict
            return {
                "type": "file",
                "file": {
                    "name": filename,
                    "mimeType": mime_type,
                    "bytes": (
                        base64.b64encode(file_bytes).decode("utf-8")
                        if isinstance(file_bytes, bytes)
                        else str(file_bytes)
                    ),
                },
            }
    elif hasattr(adk_part, "text") and adk_part.text:
        # Convert text part
        return {"type": "text", "text": adk_part.text}

    return None


def get_extension_from_mime(mime_type: str) -> str:
    """
    Get a file extension from MIME type.

    Args:
        mime_type: MIME type string

    Returns:
        Appropriate file extension with leading dot
    """
    if not mime_type:
        return ""

    mime_map = {
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/gif": ".gif",
        "application/pdf": ".pdf",
        "text/plain": ".txt",
        "text/html": ".html",
        "text/csv": ".csv",
        "application/json": ".json",
        "application/xml": ".xml",
        "application/msword": ".doc",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.ms-excel": ".xls",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
    }

    return mime_map.get(mime_type, "")
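For file parts, the two converters are intended to be near-inverses. A round-trip sanity check, assuming `a2a_part_to_adk_part` and `adk_part_to_a2a_part` from the module above are in scope (the module's import path is not shown in this diff):

```python
import base64

a2a_part = {
    "type": "file",
    "file": {
        "name": "doc.pdf",
        "mimeType": "application/pdf",
        "bytes": base64.b64encode(b"%PDF-1.7").decode("utf-8"),
    },
}

adk_part = a2a_part_to_adk_part(a2a_part)                  # A2A dict -> ADK Part
round_tripped = adk_part_to_a2a_part(adk_part, "doc.pdf")  # ADK Part -> A2A dict

assert round_tripped["file"]["mimeType"] == "application/pdf"
assert base64.b64decode(round_tripped["file"]["bytes"]) == b"%PDF-1.7"
```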