feat(a2a): add file support and multimodal content processing for A2A protocol

This commit is contained in:
Davidson Gomes 2025-05-14 22:15:08 -03:00
parent 958eeec4a6
commit 6bf0ea52e0
8 changed files with 869 additions and 43 deletions

View File

@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add Task Agent for structured single-task execution - Add Task Agent for structured single-task execution
- Improve context management in agent execution - Improve context management in agent execution
- Add file support for A2A protocol (Agent-to-Agent) endpoints
- Implement multimodal content processing in A2A messages
## [0.0.9] - 2025-05-13 ## [0.0.9] - 2025-05-13

View File

@ -31,6 +31,7 @@
Routes for the A2A (Agent-to-Agent) protocol. Routes for the A2A (Agent-to-Agent) protocol.
This module implements the standard A2A routes according to the specification. This module implements the standard A2A routes according to the specification.
Supports both text messages and file uploads through the message parts mechanism.
""" """
import uuid import uuid
@ -92,7 +93,39 @@ async def process_a2a_request(
db: Session = Depends(get_db), db: Session = Depends(get_db),
a2a_service: A2AService = Depends(get_a2a_service), a2a_service: A2AService = Depends(get_a2a_service),
): ):
"""Processes an A2A request.""" """
Processes an A2A request.
Supports both text messages and file uploads. For file uploads,
include file parts in the message following the A2A protocol format:
{
"jsonrpc": "2.0",
"id": "request-id",
"method": "tasks/send",
"params": {
"id": "task-id",
"sessionId": "session-id",
"message": {
"role": "user",
"parts": [
{
"type": "text",
"text": "Analyze this image"
},
{
"type": "file",
"file": {
"name": "example.jpg",
"mimeType": "image/jpeg",
"bytes": "base64-encoded-content"
}
}
]
}
}
}
"""
# Verify the API key # Verify the API key
if not verify_api_key(db, x_api_key): if not verify_api_key(db, x_api_key):
raise HTTPException(status_code=401, detail="Invalid API key") raise HTTPException(status_code=401, detail="Invalid API key")
@ -100,10 +133,60 @@ async def process_a2a_request(
# Process the request # Process the request
try: try:
request_body = await request.json() request_body = await request.json()
debug_request_body = {}
if "method" in request_body:
debug_request_body["method"] = request_body["method"]
if "id" in request_body:
debug_request_body["id"] = request_body["id"]
logger.info(f"A2A request received: {debug_request_body}")
# Log if request contains file parts for debugging
if isinstance(request_body, dict) and "params" in request_body:
params = request_body.get("params", {})
message = params.get("message", {})
parts = message.get("parts", [])
logger.info(f"A2A message contains {len(parts)} parts")
for i, part in enumerate(parts):
if not isinstance(part, dict):
logger.warning(f"Part {i+1} is not a dictionary: {type(part)}")
continue
part_type = part.get("type")
logger.info(f"Part {i+1} type: {part_type}")
if part_type == "file":
file_info = part.get("file", {})
logger.info(
f"File part found: {file_info.get('name')} ({file_info.get('mimeType')})"
)
if "bytes" in file_info:
bytes_data = file_info.get("bytes", "")
bytes_size = len(bytes_data) * 0.75
logger.info(f"File size: ~{bytes_size/1024:.2f} KB")
if bytes_data:
sample = (
bytes_data[:10] + "..."
if len(bytes_data) > 10
else bytes_data
)
logger.info(f"Sample of base64 data: {sample}")
elif part_type == "text":
text_content = part.get("text", "")
preview = (
text_content[:30] + "..."
if len(text_content) > 30
else text_content
)
logger.info(f"Text part found: '{preview}'")
result = await a2a_service.process_request(agent_id, request_body) result = await a2a_service.process_request(agent_id, request_body)
# If the response is a streaming response, return as EventSourceResponse # If the response is a streaming response, return as EventSourceResponse
if hasattr(result, "__aiter__"): if hasattr(result, "__aiter__"):
logger.info("Returning streaming response")
async def event_generator(): async def event_generator():
async for item in result: async for item in result:
@ -115,11 +198,15 @@ async def process_a2a_request(
return EventSourceResponse(event_generator()) return EventSourceResponse(event_generator())
# Otherwise, return as JSONResponse # Otherwise, return as JSONResponse
logger.info("Returning standard JSON response")
if hasattr(result, "model_dump"): if hasattr(result, "model_dump"):
return JSONResponse(result.model_dump(exclude_none=True)) return JSONResponse(result.model_dump(exclude_none=True))
return JSONResponse(result) return JSONResponse(result)
except Exception as e: except Exception as e:
logger.error(f"Error processing A2A request: {e}") logger.error(f"Error processing A2A request: {e}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content={ content={

View File

@ -28,6 +28,7 @@
""" """
import uuid import uuid
import base64
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Depends, Depends,
@ -47,7 +48,7 @@ from src.core.jwt_middleware import (
from src.services import ( from src.services import (
agent_service, agent_service,
) )
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
from src.services.agent_runner import run_agent, run_agent_stream from src.services.agent_runner import run_agent, run_agent_stream
from src.core.exceptions import AgentNotFoundError from src.core.exceptions import AgentNotFoundError
from src.services.service_providers import ( from src.services.service_providers import (
@ -59,7 +60,7 @@ from src.services.service_providers import (
from datetime import datetime from datetime import datetime
import logging import logging
import json import json
from typing import Optional, Dict from typing import Optional, Dict, List, Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -195,6 +196,29 @@ async def websocket_chat(
if not message: if not message:
continue continue
files = None
if data.get("files") and isinstance(data.get("files"), list):
try:
files = []
for file_data in data.get("files"):
if (
isinstance(file_data, dict)
and file_data.get("filename")
and file_data.get("content_type")
and file_data.get("data")
):
files.append(
FileData(
filename=file_data.get("filename"),
content_type=file_data.get("content_type"),
data=file_data.get("data"),
)
)
logger.info(f"Processed {len(files)} files via WebSocket")
except Exception as e:
logger.error(f"Error processing files: {str(e)}")
files = None
async for chunk in run_agent_stream( async for chunk in run_agent_stream(
agent_id=agent_id, agent_id=agent_id,
external_id=external_id, external_id=external_id,
@ -203,6 +227,7 @@ async def websocket_chat(
artifacts_service=artifacts_service, artifacts_service=artifacts_service,
memory_service=memory_service, memory_service=memory_service,
db=db, db=db,
files=files,
): ):
await websocket.send_json( await websocket.send_json(
{"message": json.loads(chunk), "turn_complete": False} {"message": json.loads(chunk), "turn_complete": False}
@ -259,6 +284,7 @@ async def chat(
artifacts_service, artifacts_service,
memory_service, memory_service,
db, db,
files=request.files,
) )
return { return {

View File

@ -30,8 +30,9 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.config.database import get_db from src.config.database import get_db
from typing import List from typing import List, Optional, Dict, Any
import uuid import uuid
import base64
from src.core.jwt_middleware import ( from src.core.jwt_middleware import (
get_jwt_token, get_jwt_token,
verify_user_client, verify_user_client,
@ -48,7 +49,7 @@ from src.services.session_service import (
get_sessions_by_agent, get_sessions_by_agent,
get_sessions_by_client, get_sessions_by_client,
) )
from src.services.service_providers import session_service from src.services.service_providers import session_service, artifacts_service
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -118,13 +119,18 @@ async def get_session(
@router.get( @router.get(
"/{session_id}/messages", "/{session_id}/messages",
response_model=List[Event],
) )
async def get_agent_messages( async def get_agent_messages(
session_id: str, session_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
payload: dict = Depends(get_jwt_token), payload: dict = Depends(get_jwt_token),
): ):
"""
Gets messages from a session with embedded artifacts.
This function loads all messages from a session and processes any references
to artifacts, loading them and converting them to base64 for direct use in the frontend.
"""
# Get the session # Get the session
session = get_session_by_id(session_service, session_id) session = get_session_by_id(session_service, session_id)
if not session: if not session:
@ -139,7 +145,160 @@ async def get_agent_messages(
if agent: if agent:
await verify_user_client(payload, db, agent.client_id) await verify_user_client(payload, db, agent.client_id)
return get_session_events(session_service, session_id) # Parse the session ID to obtain app_name and user_id
parts = session_id.split("_")
if len(parts) != 2:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid session ID format"
)
user_id, app_name = parts[0], parts[1]
events = get_session_events(session_service, session_id)
processed_events = []
for event in events:
event_dict = event.dict()
def process_dict(d):
    """Recursively replace ``bytes`` values with base64 text, in place.

    Walks nested dicts and lists so that the resulting structure contains
    no raw ``bytes`` and can be JSON-serialized.

    Args:
        d: A dict or list to sanitize. Any other value is returned untouched.

    Returns:
        The same object that was passed in (containers are mutated in place).
        A value that fails to encode is replaced by ``None``.
    """
    if isinstance(d, dict):
        # Iterate over a snapshot because values are reassigned while looping.
        for key, value in list(d.items()):
            if isinstance(value, bytes):
                try:
                    d[key] = base64.b64encode(value).decode("utf-8")
                    logger.debug(f"Converted bytes field to base64: {key}")
                except Exception as e:
                    logger.error(f"Error encoding bytes to base64: {str(e)}")
                    d[key] = None
            elif isinstance(value, (dict, list)):
                # Recurse into the container itself: for a list this reaches
                # the list branch below, which also converts bytes items that
                # sit directly inside the list (previously they were skipped).
                process_dict(value)
    elif isinstance(d, list):
        for i, item in enumerate(d):
            if isinstance(item, bytes):
                try:
                    d[i] = base64.b64encode(item).decode("utf-8")
                except Exception as e:
                    logger.error(
                        f"Error encoding bytes to base64 in list: {str(e)}"
                    )
                    d[i] = None
            elif isinstance(item, (dict, list)):
                process_dict(item)
    return d
# Process all event dictionary
event_dict = process_dict(event_dict)
# Process the content parts specifically
if event_dict.get("content") and event_dict["content"].get("parts"):
for part in event_dict["content"]["parts"]:
# Process inlineData if present
if part and part.get("inlineData") and part["inlineData"].get("data"):
# Check if it's already a string or if it's bytes
if isinstance(part["inlineData"]["data"], bytes):
# Convert bytes to base64 string
part["inlineData"]["data"] = base64.b64encode(
part["inlineData"]["data"]
).decode("utf-8")
logger.debug(
f"Converted binary data to base64 in message {event_dict.get('id')}"
)
# Process fileData if present (reference to an artifact)
if part and part.get("fileData") and part["fileData"].get("fileId"):
try:
# Extract the file name from the fileId
file_id = part["fileData"]["fileId"]
# Load the artifact from the artifacts service
artifact = artifacts_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=file_id,
)
if artifact and hasattr(artifact, "inline_data"):
# Extract the data and MIME type
file_bytes = artifact.inline_data.data
mime_type = artifact.inline_data.mime_type
# Add inlineData with the artifact data
if not part.get("inlineData"):
part["inlineData"] = {}
# Ensure we're sending a base64 string, not bytes
if isinstance(file_bytes, bytes):
try:
part["inlineData"]["data"] = base64.b64encode(
file_bytes
).decode("utf-8")
except Exception as e:
logger.error(
f"Error encoding artifact to base64: {str(e)}"
)
part["inlineData"]["data"] = None
else:
part["inlineData"]["data"] = str(file_bytes)
part["inlineData"]["mimeType"] = mime_type
logger.debug(
f"Loaded artifact {file_id} for message {event_dict.get('id')}"
)
except Exception as e:
logger.error(f"Error loading artifact: {str(e)}")
# Don't interrupt the flow if an artifact fails
# Check artifact_delta in actions
if event_dict.get("actions") and event_dict["actions"].get("artifact_delta"):
artifact_deltas = event_dict["actions"]["artifact_delta"]
for filename, version in artifact_deltas.items():
try:
# Load the artifact
artifact = artifacts_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
version=version,
)
if artifact and hasattr(artifact, "inline_data"):
# If the event doesn't have an artifacts section, create it
if "artifacts" not in event_dict:
event_dict["artifacts"] = {}
# Add the artifact to the event's artifacts list
file_bytes = artifact.inline_data.data
mime_type = artifact.inline_data.mime_type
# Ensure the bytes are converted to base64
event_dict["artifacts"][filename] = {
"data": (
base64.b64encode(file_bytes).decode("utf-8")
if isinstance(file_bytes, bytes)
else str(file_bytes)
),
"mimeType": mime_type,
"version": version,
}
logger.debug(
f"Added artifact {filename} (v{version}) to message {event_dict.get('id')}"
)
except Exception as e:
logger.error(
f"Error processing artifact_delta {filename}: {str(e)}"
)
processed_events.append(event_dict)
return processed_events
@router.delete( @router.delete(

View File

@ -27,36 +27,42 @@
""" """
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, validator
from typing import Dict, Any, Optional from typing import Dict, List, Optional, Any
from datetime import datetime
class FileData(BaseModel):
    """Model to represent file data sent in a chat request."""

    # Original file name; the agent runner passes it as the artifact
    # filename when saving the file to the artifact service.
    filename: str = Field(..., description="File name")
    # MIME type of the file; the agent runner uses it as the Blob mime_type
    # when building the inline-data Part sent to the agent.
    content_type: str = Field(..., description="File content type")
    # File payload as base64 text (not raw bytes); the agent runner decodes
    # it with base64.b64decode before use.
    data: str = Field(..., description="File content encoded in base64")
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
"""Schema for chat requests""" """Model to represent a chat request."""
agent_id: str = Field( agent_id: str = Field(..., description="Agent ID to process the message")
..., description="ID of the agent that will process the message" external_id: str = Field(..., description="External ID for user identification")
message: str = Field(..., description="User message to the agent")
files: Optional[List[FileData]] = Field(
None, description="List of files attached to the message"
) )
external_id: str = Field(
..., description="ID of the external_id that will process the message"
)
message: str = Field(..., description="User message")
class ChatResponse(BaseModel): class ChatResponse(BaseModel):
"""Schema for chat responses""" """Model to represent a chat response."""
response: str = Field(..., description="Agent response") response: str = Field(..., description="Response generated by the agent")
status: str = Field(..., description="Operation status") message_history: List[Dict[str, Any]] = Field(
error: Optional[str] = Field(None, description="Error message, if there is one") default_factory=list, description="Message history"
timestamp: str = Field(..., description="Timestamp of the response") )
status: str = Field(..., description="Response status (success/error)")
timestamp: str = Field(..., description="Response timestamp")
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""Schema for error responses""" """Model to represent an error response."""
error: str = Field(..., description="Error message") detail: str = Field(..., description="Error details")
status_code: int = Field(..., description="HTTP status code of the error")
details: Optional[Dict[str, Any]] = Field(
None, description="Additional error details"
)

View File

@ -33,8 +33,11 @@ from collections.abc import AsyncIterable
from typing import Dict, Optional from typing import Dict, Optional
from uuid import UUID from uuid import UUID
import json import json
import base64
import uuid as uuid_pkg
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from google.genai.types import Part, Blob
from src.config.settings import settings from src.config.settings import settings
from src.services.agent_service import ( from src.services.agent_service import (
@ -76,6 +79,7 @@ from src.schemas.a2a_types import (
AgentAuthentication, AgentAuthentication,
AgentProvider, AgentProvider,
) )
from src.schemas.chat import FileData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -281,12 +285,29 @@ class A2ATaskManager:
all_messages.append(agent_message) all_messages.append(agent_message)
task_state = self._determine_task_state(result) task_state = self._determine_task_state(result)
artifact = Artifact(parts=agent_message.parts, index=0)
# Create artifacts for any file content
artifacts = []
# First, add the main response as an artifact
artifacts.append(Artifact(parts=agent_message.parts, index=0))
# Also add any files from the message history
for idx, msg in enumerate(all_messages, 1):
for part in msg.parts:
if hasattr(part, "type") and part.type == "file":
artifacts.append(
Artifact(
parts=[part],
index=idx,
name=part.file.name,
description=f"File from message {idx}",
)
)
task = await self.update_store( task = await self.update_store(
task_params.id, task_params.id,
TaskStatus(state=task_state, message=agent_message), TaskStatus(state=task_state, message=agent_message),
[artifact], artifacts,
) )
await self._update_task_history( await self._update_task_history(
@ -400,6 +421,32 @@ class A2ATaskManager:
final_message = None final_message = None
# Check for files in the user message and include them as artifacts
user_files = []
for part in request.params.message.parts:
if (
hasattr(part, "type")
and part.type == "file"
and hasattr(part, "file")
):
user_files.append(
Artifact(
parts=[part],
index=0,
name=part.file.name if part.file.name else "file",
description="File from user",
)
)
# Send artifacts for any user files
for artifact in user_files:
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(
id=request.params.id, artifact=artifact
),
)
async for chunk in run_agent_stream( async for chunk in run_agent_stream(
agent_id=str(agent.id), agent_id=str(agent.id),
external_id=external_id, external_id=external_id,
@ -418,7 +465,48 @@ class A2ATaskManager:
parts = content.get("parts", []) parts = content.get("parts", [])
if parts: if parts:
update_message = Message(role=role, parts=parts) # Modify to handle file parts as well
agent_parts = []
for part in parts:
# Handle different part types
if part.get("type") == "text":
agent_parts.append(part)
full_response += part.get("text", "")
elif part.get("inlineData") and part["inlineData"].get(
"data"
):
# Convert inline data to file part
mime_type = part["inlineData"].get(
"mimeType", "application/octet-stream"
)
file_name = f"file_{uuid_pkg.uuid4().hex}{self._get_extension_from_mime(mime_type)}"
file_part = {
"type": "file",
"file": {
"name": file_name,
"mimeType": mime_type,
"bytes": part["inlineData"]["data"],
},
}
agent_parts.append(file_part)
# Also send as artifact
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(
id=request.params.id,
artifact=Artifact(
parts=[file_part],
index=0,
name=file_name,
description=f"Generated {mime_type} file",
),
),
)
if agent_parts:
update_message = Message(role=role, parts=agent_parts)
final_message = update_message
yield SendTaskStreamingResponse( yield SendTaskStreamingResponse(
id=request.id, id=request.id,
@ -431,11 +519,6 @@ class A2ATaskManager:
final=False, final=False,
), ),
) )
for part in parts:
if part.get("type") == "text":
full_response += part.get("text", "")
final_message = update_message
except Exception as e: except Exception as e:
logger.error(f"Error processing chunk: {e}, chunk: {chunk}") logger.error(f"Error processing chunk: {e}, chunk: {chunk}")
continue continue
@ -485,6 +568,29 @@ class A2ATaskManager:
error=InternalError(message=f"Error streaming task process: {str(e)}"), error=InternalError(message=f"Error streaming task process: {str(e)}"),
) )
def _get_extension_from_mime(self, mime_type: str) -> str:
    """Return the file extension (including the leading dot) for a MIME type.

    The lookup ignores letter case, surrounding whitespace and any
    parameters after ``;`` (for example ``text/html; charset=utf-8``).

    Args:
        mime_type: MIME type string; may be empty or ``None``.

    Returns:
        The mapped extension such as ``".png"``, or ``""`` when the MIME
        type is empty or not in the table.
    """
    if not mime_type:
        return ""

    mime_map = {
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/gif": ".gif",
        "application/pdf": ".pdf",
        "text/plain": ".txt",
        "text/html": ".html",
        "text/csv": ".csv",
        "application/json": ".json",
        "application/xml": ".xml",
        "application/msword": ".doc",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/vnd.ms-excel": ".xls",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
    }

    # Media types are case-insensitive and may carry parameters; reduce the
    # value to its bare "type/subtype" form before looking it up.
    essence = mime_type.split(";", 1)[0].strip().lower()
    return mime_map.get(essence, "")
async def update_store( async def update_store(
self, self,
task_id: str, task_id: str,
@ -514,19 +620,193 @@ class A2ATaskManager:
return task return task
def _extract_user_query(self, task_params: TaskSendParams) -> str: def _extract_user_query(self, task_params: TaskSendParams) -> str:
"""Extracts the user query from the task parameters.""" """Extracts the user query from the task parameters and processes any files."""
if not task_params.message or not task_params.message.parts: if not task_params.message or not task_params.message.parts:
raise ValueError("Message or parts are missing in task parameters") raise ValueError("Message or parts are missing in task parameters")
part = task_params.message.parts[0] # Process file parts first
if part.type != "text": text_parts = []
raise ValueError("Only text parts are supported") has_files = False
file_parts = []
return part.text logger.info(
f"Extracting query from message with {len(task_params.message.parts)} parts"
)
# Extract text parts and file parts separately
for idx, part in enumerate(task_params.message.parts):
logger.info(
f"Processing part {idx+1}, type: {getattr(part, 'type', 'unknown')}"
)
if hasattr(part, "type"):
if part.type == "text":
logger.info(f"Found text part: '{part.text[:50]}...' (truncated)")
text_parts.append(part.text)
elif part.type == "file":
logger.info(
f"Found file part: {getattr(getattr(part, 'file', None), 'name', 'unnamed')}"
)
has_files = True
try:
processed_file = self._process_file_part(
part, task_params.sessionId
)
if processed_file:
file_parts.append(processed_file)
except Exception as e:
logger.error(f"Error processing file part: {e}")
# Continue with other parts even if a file fails
else:
logger.warning(f"Unknown part type: {part.type}")
else:
logger.warning(f"Part has no type attribute: {part}")
# Store the file parts in self for later use
self._last_processed_files = file_parts if file_parts else None
# If we have at least one text part, use that as the query
if text_parts:
final_query = " ".join(text_parts)
logger.info(
f"Final query from text parts: '{final_query[:50]}...' (truncated)"
)
return final_query
# If we only have file parts, create a generic query asking for analysis
elif has_files:
logger.info("No text parts, using generic query for file analysis")
return "Analyze the attached files"
else:
logger.error("No supported content parts found in the message")
raise ValueError("No supported content parts found in the message")
def _process_file_part(self, part, session_id: str):
    """Decode an inline file part, save it as an artifact and wrap it for the agent runner.

    Side effect: when the file has no name, a generated ``file_<hex>`` name
    is written back onto ``part.file.name``.

    Args:
        part: Message part expected to expose a ``file`` attribute carrying
            ``name``, ``bytes`` (base64 text), ``mimeType`` and ``uri``.
        session_id: Session the file belongs to. When it splits on ``_``
            into exactly two pieces they are used as ``user_id`` and
            ``app_name`` for the artifact; otherwise the whole ID is the
            user and the app name is ``"a2a"``.

    Returns:
        A ``FileData`` holding the file name, the resolved MIME type and the
        original base64 payload, or ``None`` when the part has no file, has
        no inline bytes, or only references a URI (not implemented yet).

    Raises:
        Exception: Any error raised while decoding or saving the file is
            logged with its traceback and re-raised.
    """
    if not hasattr(part, "file") or not part.file:
        logger.warning("File part missing file data")
        return None

    file_data = part.file

    # The artifact service needs a filename, so synthesize one if absent.
    if not file_data.name:
        file_data.name = f"file_{uuid_pkg.uuid4().hex}"

    logger.info(f"Processing file {file_data.name} for session {session_id}")

    if not file_data.bytes:
        if file_data.uri:
            # Fetching remote content would require additional implementation.
            logger.warning(f"File URI references not yet implemented: {file_data.uri}")
        return None

    try:
        logger.info(f"Decoding base64 content for file {file_data.name}")
        file_bytes = base64.b64decode(file_data.bytes)

        # Trust a specific declared type; sniff the content only when the
        # type is missing or is the generic binary type.
        mime_type = getattr(file_data, "mimeType", None)
        if not mime_type or mime_type == "application/octet-stream":
            mime_type = self._detect_mime_type(file_bytes, file_data.name)

        logger.info(
            f"Decoded file size: {len(file_bytes)} bytes, MIME type: {mime_type}"
        )

        # Split session_id to get user_id and app_name.
        parts = session_id.split("_")
        if len(parts) != 2:
            user_id = session_id
            app_name = "a2a"
        else:
            user_id, app_name = parts

        logger.info(f"Creating artifact Part for file {file_data.name}")
        artifact = Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))

        logger.info(
            f"Saving artifact {file_data.name} to {app_name}/{user_id}/{session_id}"
        )
        version = artifacts_service.save_artifact(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=file_data.name,
            artifact=artifact,
        )
        logger.info(
            f"Successfully saved file {file_data.name} (version {version}) for session {session_id}"
        )

        # The payload stays base64-encoded because the agent runner decodes
        # FileData.data itself.
        return FileData(
            filename=file_data.name,
            content_type=mime_type,
            data=file_data.bytes,
        )
    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"Error processing file data: {str(e)}")
        raise

@staticmethod
def _detect_mime_type(file_bytes: bytes, filename: str) -> str:
    """Guess a MIME type from leading magic bytes, then from the file extension."""
    signatures = (
        (b"\xff\xd8\xff", "image/jpeg"),
        (b"\x89PNG\r\n\x1a\n", "image/png"),
        (b"GIF87a", "image/gif"),
        (b"GIF89a", "image/gif"),
        (b"%PDF", "application/pdf"),
    )
    for magic, detected in signatures:
        if file_bytes.startswith(magic):
            return detected

    # No known signature matched: fall back to the file extension.
    extensions = (
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".png",), "image/png"),
        ((".gif",), "image/gif"),
        ((".pdf",), "application/pdf"),
    )
    lowered = filename.lower()
    for suffixes, detected in extensions:
        if lowered.endswith(suffixes):
            return detected

    return "application/octet-stream"
async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict: async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict:
"""Executes the agent to process the user query.""" """Executes the agent to process the user query."""
try: try:
files = getattr(self, "_last_processed_files", None)
if files:
logger.info(f"Passing {len(files)} files to run_agent")
for file_info in files:
logger.info(
f"File being sent: {file_info.filename} ({file_info.content_type})"
)
else:
logger.info("No files to pass to run_agent")
# We call the same function used in the chat API # We call the same function used in the chat API
return await run_agent( return await run_agent(
agent_id=str(agent.id), agent_id=str(agent.id),
@ -536,6 +816,7 @@ class A2ATaskManager:
artifacts_service=artifacts_service, artifacts_service=artifacts_service,
memory_service=memory_service, memory_service=memory_service,
db=self.db, db=self.db,
files=files,
) )
except Exception as e: except Exception as e:
logger.error(f"Error running agent: {e}") logger.error(f"Error running agent: {e}")

View File

@ -28,7 +28,7 @@
""" """
from google.adk.runners import Runner from google.adk.runners import Runner
from google.genai.types import Content, Part from google.genai.types import Content, Part, Blob
from google.adk.sessions import DatabaseSessionService from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService from google.adk.memory import InMemoryMemoryService
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
@ -42,6 +42,7 @@ import asyncio
import json import json
from src.utils.otel import get_tracer from src.utils.otel import get_tracer
from opentelemetry import trace from opentelemetry import trace
import base64
logger = setup_logger(__name__) logger = setup_logger(__name__)
@ -56,6 +57,7 @@ async def run_agent(
db: Session, db: Session,
session_id: Optional[str] = None, session_id: Optional[str] = None,
timeout: float = 60.0, timeout: float = 60.0,
files: Optional[list] = None,
): ):
tracer = get_tracer() tracer = get_tracer()
with tracer.start_as_current_span( with tracer.start_as_current_span(
@ -65,6 +67,7 @@ async def run_agent(
"external_id": external_id, "external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}", "session_id": session_id or f"{external_id}_{agent_id}",
"message": message, "message": message,
"has_files": files is not None and len(files) > 0,
}, },
): ):
exit_stack = None exit_stack = None
@ -74,6 +77,9 @@ async def run_agent(
) )
logger.info(f"Received message: {message}") logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id) get_root_agent = get_agent(db, agent_id)
logger.info( logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})" f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@ -113,7 +119,63 @@ async def run_agent(
session_id=adk_session_id, session_id=adk_session_id,
) )
content = Content(role="user", parts=[Part(text=message)]) file_parts = []
if files and len(files) > 0:
for file_data in files:
try:
file_bytes = base64.b64decode(file_data.data)
logger.info(f"DEBUG - Processing file: {file_data.filename}")
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
logger.info(f"DEBUG - MIME type: '{file_data.content_type}'")
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
try:
file_part = Part(
inline_data=Blob(
mime_type=file_data.content_type, data=file_bytes
)
)
logger.info(f"DEBUG - Part created successfully")
except Exception as part_error:
logger.error(
f"DEBUG - Error creating Part: {str(part_error)}"
)
logger.error(
f"DEBUG - Error type: {type(part_error).__name__}"
)
import traceback
logger.error(
f"DEBUG - Stack trace: {traceback.format_exc()}"
)
raise
# Save the file in the ArtifactService
version = artifacts_service.save_artifact(
app_name=agent_id,
user_id=external_id,
session_id=adk_session_id,
filename=file_data.filename,
artifact=file_part,
)
logger.info(
f"Saved file {file_data.filename} as version {version}"
)
# Add the Part to the list of parts for the message content
file_parts.append(file_part)
except Exception as e:
logger.error(
f"Error processing file {file_data.filename}: {str(e)}"
)
# Create the content with the text message and the files
parts = [Part(text=message)]
if file_parts:
parts.extend(file_parts)
content = Content(role="user", parts=parts)
logger.info("Starting agent execution") logger.info("Starting agent execution")
final_response_text = "No final response captured." final_response_text = "No final response captured."
@ -256,6 +318,7 @@ async def run_agent_stream(
memory_service: InMemoryMemoryService, memory_service: InMemoryMemoryService,
db: Session, db: Session,
session_id: Optional[str] = None, session_id: Optional[str] = None,
files: Optional[list] = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
tracer = get_tracer() tracer = get_tracer()
span = tracer.start_span( span = tracer.start_span(
@ -265,6 +328,7 @@ async def run_agent_stream(
"external_id": external_id, "external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}", "session_id": session_id or f"{external_id}_{agent_id}",
"message": message, "message": message,
"has_files": files is not None and len(files) > 0,
}, },
) )
try: try:
@ -275,6 +339,9 @@ async def run_agent_stream(
) )
logger.info(f"Received message: {message}") logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id) get_root_agent = get_agent(db, agent_id)
logger.info( logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})" f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@ -314,7 +381,72 @@ async def run_agent_stream(
session_id=adk_session_id, session_id=adk_session_id,
) )
content = Content(role="user", parts=[Part(text=message)]) # Process the received files
file_parts = []
if files and len(files) > 0:
for file_data in files:
try:
# Decode the base64 file
file_bytes = base64.b64decode(file_data.data)
# Detailed debug
logger.info(
f"DEBUG - Processing file: {file_data.filename}"
)
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
logger.info(
f"DEBUG - MIME type: '{file_data.content_type}'"
)
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
# Create a Part for the file using the default constructor
try:
file_part = Part(
inline_data=Blob(
mime_type=file_data.content_type,
data=file_bytes,
)
)
logger.info(f"DEBUG - Part created successfully")
except Exception as part_error:
logger.error(
f"DEBUG - Error creating Part: {str(part_error)}"
)
logger.error(
f"DEBUG - Error type: {type(part_error).__name__}"
)
import traceback
logger.error(
f"DEBUG - Stack trace: {traceback.format_exc()}"
)
raise
# Save the file in the ArtifactService
version = artifacts_service.save_artifact(
app_name=agent_id,
user_id=external_id,
session_id=adk_session_id,
filename=file_data.filename,
artifact=file_part,
)
logger.info(
f"Saved file {file_data.filename} as version {version}"
)
# Add the Part to the list of parts for the message content
file_parts.append(file_part)
except Exception as e:
logger.error(
f"Error processing file {file_data.filename}: {str(e)}"
)
# Create the content with the text message and the files
parts = [Part(text=message)]
if file_parts:
parts.extend(file_parts)
content = Content(role="user", parts=parts)
logger.info("Starting agent streaming execution") logger.info("Starting agent streaming execution")
try: try:

View File

@ -27,10 +27,16 @@
""" """
import base64
import uuid
from typing import Dict, List, Any, Optional
from google.genai.types import Part, Blob
from src.schemas.a2a_types import ( from src.schemas.a2a_types import (
ContentTypeNotSupportedError, ContentTypeNotSupportedError,
JSONRPCResponse, JSONRPCResponse,
UnsupportedOperationError, UnsupportedOperationError,
Message,
) )
@ -55,3 +61,130 @@ def new_incompatible_types_error(request_id):
def new_not_implemented_error(request_id):
    """
    Build the JSON-RPC response used when an operation is not supported.

    Args:
        request_id: Identifier of the request being answered; it is passed
            through unchanged as the response id

    Returns:
        A JSONRPCResponse whose error is an UnsupportedOperationError
    """
    unsupported = UnsupportedOperationError()
    return JSONRPCResponse(id=request_id, error=unsupported)
def extract_files_from_message(message: Message) -> List[Dict[str, Any]]:
    """
    Collect the file parts carried by an A2A message.

    Args:
        message: An A2A Message object (may be None or have no parts)

    Returns:
        The parts whose ``type`` is "file" and that expose a ``file``
        attribute, in message order; an empty list when there are none
    """
    candidates = message.parts if message else None
    if not candidates:
        return []

    return [
        candidate
        for candidate in candidates
        if hasattr(candidate, "type")
        and candidate.type == "file"
        and hasattr(candidate, "file")
    ]
def a2a_part_to_adk_part(a2a_part: Dict[str, Any]) -> Optional[Part]:
    """
    Convert an A2A protocol part to an ADK Part object.

    Only file parts that carry inline base64 content under "bytes" are
    converted. Text parts, file parts without a payload, and payloads that
    cannot be decoded all yield None.

    Args:
        a2a_part: An A2A part dictionary

    Returns:
        Converted ADK Part object or None if conversion not possible
    """
    if a2a_part.get("type") != "file":
        # Text (and any other) parts are not converted by this helper.
        return None

    file_data = a2a_part.get("file")
    if not file_data:
        # Covers a missing key as well as an explicit "file": null, which
        # would otherwise fail on the lookups below.
        return None

    encoded = file_data.get("bytes")
    if encoded is None:
        return None

    try:
        file_bytes = base64.b64decode(encoded)
        # "mimeType" may be present but null, so fall back on any falsy value
        # rather than only on a missing key.
        mime_type = file_data.get("mimeType") or "application/octet-stream"
        return Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))
    except (ValueError, TypeError):
        # binascii.Error (malformed base64) is a ValueError; TypeError covers
        # a payload that is not str/bytes. The same error types raised while
        # building the Part are likewise treated as "not convertible".
        # Anything else indicates a real bug and is allowed to propagate.
        return None
def adk_part_to_a2a_part(
    adk_part: Part, filename: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Convert an ADK Part object to an A2A protocol part.

    Inline binary data becomes an A2A file part with base64-encoded content;
    a text part becomes an A2A text part.

    Args:
        adk_part: An ADK Part object
        filename: Optional filename to use; when omitted, a random name is
            generated with an extension derived from the MIME type

    Returns:
        Converted A2A Part dictionary or None if conversion not possible
    """
    inline_data = getattr(adk_part, "inline_data", None)
    if inline_data:
        file_bytes = inline_data.data
        mime_type = inline_data.mime_type
        if not (file_bytes and mime_type):
            # Inline data without content or MIME type cannot be represented.
            return None

        # Generate filename if not provided
        if not filename:
            ext = get_extension_from_mime(mime_type)
            filename = f"file_{uuid.uuid4().hex}{ext}"

        if isinstance(file_bytes, (bytes, bytearray, memoryview)):
            # Any bytes-like payload is base64-encoded; str() on a bytearray
            # or memoryview would emit its repr instead of the content.
            encoded = base64.b64encode(file_bytes).decode("utf-8")
        else:
            # Non-binary payload (presumably already base64 text): pass it
            # through as a string.
            encoded = str(file_bytes)

        return {
            "type": "file",
            "file": {
                "name": filename,
                "mimeType": mime_type,
                "bytes": encoded,
            },
        }

    text = getattr(adk_part, "text", None)
    if text:
        return {"type": "text", "text": text}

    return None
# Known MIME types and the file extension used when generating filenames.
_MIME_EXTENSIONS = {
    "image/jpeg": ".jpg",
    "image/png": ".png",
    "image/gif": ".gif",
    "application/pdf": ".pdf",
    "text/plain": ".txt",
    "text/html": ".html",
    "text/csv": ".csv",
    "application/json": ".json",
    "application/xml": ".xml",
    "application/msword": ".doc",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "application/vnd.ms-excel": ".xls",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
}


def get_extension_from_mime(mime_type: str) -> str:
    """
    Get a file extension from MIME type.

    The lookup ignores letter case, surrounding whitespace and any
    parameters after a ";" (for example "Text/Plain; charset=utf-8"
    resolves to ".txt").

    Args:
        mime_type: MIME type string

    Returns:
        Appropriate file extension with leading dot, or an empty string when
        the MIME type is empty or not recognized
    """
    if not mime_type:
        return ""

    # MIME type/subtype names are case-insensitive and may be followed by
    # parameters; reduce the value to its bare "type/subtype" form first.
    essence = mime_type.split(";", 1)[0].strip().lower()
    return _MIME_EXTENSIONS.get(essence, "")