feat(a2a): add file support and multimodal content processing for A2A protocol

This commit is contained in:
Davidson Gomes 2025-05-14 22:15:08 -03:00
parent 958eeec4a6
commit 6bf0ea52e0
8 changed files with 869 additions and 43 deletions

View File

@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add Task Agent for structured single-task execution
- Improve context management in agent execution
- Add file support for A2A protocol (Agent-to-Agent) endpoints
- Implement multimodal content processing in A2A messages
## [0.0.9] - 2025-05-13

View File

@ -31,6 +31,7 @@
Routes for the A2A (Agent-to-Agent) protocol.
This module implements the standard A2A routes according to the specification.
Supports both text messages and file uploads through the message parts mechanism.
"""
import uuid
@ -92,7 +93,39 @@ async def process_a2a_request(
db: Session = Depends(get_db),
a2a_service: A2AService = Depends(get_a2a_service),
):
"""Processes an A2A request."""
"""
Processes an A2A request.
Supports both text messages and file uploads. For file uploads,
include file parts in the message following the A2A protocol format:
{
"jsonrpc": "2.0",
"id": "request-id",
"method": "tasks/send",
"params": {
"id": "task-id",
"sessionId": "session-id",
"message": {
"role": "user",
"parts": [
{
"type": "text",
"text": "Analyze this image"
},
{
"type": "file",
"file": {
"name": "example.jpg",
"mimeType": "image/jpeg",
"bytes": "base64-encoded-content"
}
}
]
}
}
}
"""
# Verify the API key
if not verify_api_key(db, x_api_key):
raise HTTPException(status_code=401, detail="Invalid API key")
@ -100,10 +133,60 @@ async def process_a2a_request(
# Process the request
try:
request_body = await request.json()
debug_request_body = {}
if "method" in request_body:
debug_request_body["method"] = request_body["method"]
if "id" in request_body:
debug_request_body["id"] = request_body["id"]
logger.info(f"A2A request received: {debug_request_body}")
# Log if request contains file parts for debugging
if isinstance(request_body, dict) and "params" in request_body:
params = request_body.get("params", {})
message = params.get("message", {})
parts = message.get("parts", [])
logger.info(f"A2A message contains {len(parts)} parts")
for i, part in enumerate(parts):
if not isinstance(part, dict):
logger.warning(f"Part {i+1} is not a dictionary: {type(part)}")
continue
part_type = part.get("type")
logger.info(f"Part {i+1} type: {part_type}")
if part_type == "file":
file_info = part.get("file", {})
logger.info(
f"File part found: {file_info.get('name')} ({file_info.get('mimeType')})"
)
if "bytes" in file_info:
bytes_data = file_info.get("bytes", "")
bytes_size = len(bytes_data) * 0.75
logger.info(f"File size: ~{bytes_size/1024:.2f} KB")
if bytes_data:
sample = (
bytes_data[:10] + "..."
if len(bytes_data) > 10
else bytes_data
)
logger.info(f"Sample of base64 data: {sample}")
elif part_type == "text":
text_content = part.get("text", "")
preview = (
text_content[:30] + "..."
if len(text_content) > 30
else text_content
)
logger.info(f"Text part found: '{preview}'")
result = await a2a_service.process_request(agent_id, request_body)
# If the response is a streaming response, return as EventSourceResponse
if hasattr(result, "__aiter__"):
logger.info("Returning streaming response")
async def event_generator():
async for item in result:
@ -115,11 +198,15 @@ async def process_a2a_request(
return EventSourceResponse(event_generator())
# Otherwise, return as JSONResponse
logger.info("Returning standard JSON response")
if hasattr(result, "model_dump"):
return JSONResponse(result.model_dump(exclude_none=True))
return JSONResponse(result)
except Exception as e:
logger.error(f"Error processing A2A request: {e}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
return JSONResponse(
status_code=500,
content={

View File

@ -28,6 +28,7 @@
"""
import uuid
import base64
from fastapi import (
APIRouter,
Depends,
@ -47,7 +48,7 @@ from src.core.jwt_middleware import (
from src.services import (
agent_service,
)
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
from src.services.agent_runner import run_agent, run_agent_stream
from src.core.exceptions import AgentNotFoundError
from src.services.service_providers import (
@ -59,7 +60,7 @@ from src.services.service_providers import (
from datetime import datetime
import logging
import json
from typing import Optional, Dict
from typing import Optional, Dict, List, Any
logger = logging.getLogger(__name__)
@ -195,6 +196,29 @@ async def websocket_chat(
if not message:
continue
files = None
if data.get("files") and isinstance(data.get("files"), list):
try:
files = []
for file_data in data.get("files"):
if (
isinstance(file_data, dict)
and file_data.get("filename")
and file_data.get("content_type")
and file_data.get("data")
):
files.append(
FileData(
filename=file_data.get("filename"),
content_type=file_data.get("content_type"),
data=file_data.get("data"),
)
)
logger.info(f"Processed {len(files)} files via WebSocket")
except Exception as e:
logger.error(f"Error processing files: {str(e)}")
files = None
async for chunk in run_agent_stream(
agent_id=agent_id,
external_id=external_id,
@ -203,6 +227,7 @@ async def websocket_chat(
artifacts_service=artifacts_service,
memory_service=memory_service,
db=db,
files=files,
):
await websocket.send_json(
{"message": json.loads(chunk), "turn_complete": False}
@ -259,6 +284,7 @@ async def chat(
artifacts_service,
memory_service,
db,
files=request.files,
)
return {

View File

@ -30,8 +30,9 @@
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from src.config.database import get_db
from typing import List
from typing import List, Optional, Dict, Any
import uuid
import base64
from src.core.jwt_middleware import (
get_jwt_token,
verify_user_client,
@ -48,7 +49,7 @@ from src.services.session_service import (
get_sessions_by_agent,
get_sessions_by_client,
)
from src.services.service_providers import session_service
from src.services.service_providers import session_service, artifacts_service
import logging
logger = logging.getLogger(__name__)
@ -118,13 +119,18 @@ async def get_session(
@router.get(
"/{session_id}/messages",
response_model=List[Event],
)
async def get_agent_messages(
session_id: str,
db: Session = Depends(get_db),
payload: dict = Depends(get_jwt_token),
):
"""
Gets messages from a session with embedded artifacts.
This function loads all messages from a session and processes any references
to artifacts, loading them and converting them to base64 for direct use in the frontend.
"""
# Get the session
session = get_session_by_id(session_service, session_id)
if not session:
@ -139,7 +145,160 @@ async def get_agent_messages(
if agent:
await verify_user_client(payload, db, agent.client_id)
return get_session_events(session_service, session_id)
# Parse the session ID to obtain user_id and app_name
parts = session_id.split("_")
if len(parts) != 2:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid session ID format"
)
user_id, app_name = parts[0], parts[1]
events = get_session_events(session_service, session_id)
processed_events = []
for event in events:
event_dict = event.dict()
def process_dict(d):
    """Recursively replace ``bytes`` values with base64 text, in place.

    Walks nested dicts and lists. Every ``bytes`` object found, whether it
    is a dict value or a list element at any depth, is replaced by its
    base64 encoding as a ``str``.

    Args:
        d: A dict or list to convert. Any other type is returned untouched.

    Returns:
        The same object that was passed in, after in-place conversion.
    """
    is_mapping = isinstance(d, dict)
    if is_mapping:
        # Snapshot the items because values are reassigned while looping.
        entries = list(d.items())
    elif isinstance(d, list):
        entries = list(enumerate(d))
    else:
        return d

    for key, value in entries:
        if isinstance(value, bytes):
            d[key] = base64.b64encode(value).decode("utf-8")
            if is_mapping:
                logger.debug(f"Converted bytes field to base64: {key}")
        elif isinstance(value, (dict, list)):
            # Recurse into lists as a whole (not only into their dict/list
            # items) so bytes stored directly in a list value are converted.
            process_dict(value)
    return d
# Process all event dictionary
event_dict = process_dict(event_dict)
# Process the content parts specifically
if event_dict.get("content") and event_dict["content"].get("parts"):
for part in event_dict["content"]["parts"]:
# Process inlineData if present
if part and part.get("inlineData") and part["inlineData"].get("data"):
# Check if it's already a string or if it's bytes
if isinstance(part["inlineData"]["data"], bytes):
# Convert bytes to base64 string
part["inlineData"]["data"] = base64.b64encode(
part["inlineData"]["data"]
).decode("utf-8")
logger.debug(
f"Converted binary data to base64 in message {event_dict.get('id')}"
)
# Process fileData if present (reference to an artifact)
if part and part.get("fileData") and part["fileData"].get("fileId"):
try:
# Extract the file name from the fileId
file_id = part["fileData"]["fileId"]
# Load the artifact from the artifacts service
artifact = artifacts_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=file_id,
)
if artifact and hasattr(artifact, "inline_data"):
# Extract the data and MIME type
file_bytes = artifact.inline_data.data
mime_type = artifact.inline_data.mime_type
# Add inlineData with the artifact data
if not part.get("inlineData"):
part["inlineData"] = {}
# Ensure we're sending a base64 string, not bytes
if isinstance(file_bytes, bytes):
try:
part["inlineData"]["data"] = base64.b64encode(
file_bytes
).decode("utf-8")
except Exception as e:
logger.error(
f"Error encoding artifact to base64: {str(e)}"
)
part["inlineData"]["data"] = None
else:
part["inlineData"]["data"] = str(file_bytes)
part["inlineData"]["mimeType"] = mime_type
logger.debug(
f"Loaded artifact {file_id} for message {event_dict.get('id')}"
)
except Exception as e:
logger.error(f"Error loading artifact: {str(e)}")
# Don't interrupt the flow if an artifact fails
# Check artifact_delta in actions
if event_dict.get("actions") and event_dict["actions"].get("artifact_delta"):
artifact_deltas = event_dict["actions"]["artifact_delta"]
for filename, version in artifact_deltas.items():
try:
# Load the artifact
artifact = artifacts_service.load_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
version=version,
)
if artifact and hasattr(artifact, "inline_data"):
# If the event doesn't have an artifacts section, create it
if "artifacts" not in event_dict:
event_dict["artifacts"] = {}
# Add the artifact to the event's artifacts list
file_bytes = artifact.inline_data.data
mime_type = artifact.inline_data.mime_type
# Ensure the bytes are converted to base64
event_dict["artifacts"][filename] = {
"data": (
base64.b64encode(file_bytes).decode("utf-8")
if isinstance(file_bytes, bytes)
else str(file_bytes)
),
"mimeType": mime_type,
"version": version,
}
logger.debug(
f"Added artifact {filename} (v{version}) to message {event_dict.get('id')}"
)
except Exception as e:
logger.error(
f"Error processing artifact_delta {filename}: {str(e)}"
)
processed_events.append(event_dict)
return processed_events
@router.delete(

View File

@ -27,36 +27,42 @@
"""
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Optional, Any
from datetime import datetime
class FileData(BaseModel):
    """A single file attachment carried by a chat request.

    All three fields are required. ``data`` holds the file content as a
    base64-encoded string.
    """

    # Name of the attached file.
    filename: str = Field(
        ...,
        description="File name",
    )
    # Content type reported for the file.
    content_type: str = Field(
        ...,
        description="File content type",
    )
    # File content, base64-encoded.
    data: str = Field(
        ...,
        description="File content encoded in base64",
    )
class ChatRequest(BaseModel):
"""Schema for chat requests"""
"""Model to represent a chat request."""
agent_id: str = Field(
..., description="ID of the agent that will process the message"
agent_id: str = Field(..., description="Agent ID to process the message")
external_id: str = Field(..., description="External ID for user identification")
message: str = Field(..., description="User message to the agent")
files: Optional[List[FileData]] = Field(
None, description="List of files attached to the message"
)
external_id: str = Field(
..., description="ID of the external_id that will process the message"
)
message: str = Field(..., description="User message")
class ChatResponse(BaseModel):
"""Schema for chat responses"""
"""Model to represent a chat response."""
response: str = Field(..., description="Agent response")
status: str = Field(..., description="Operation status")
error: Optional[str] = Field(None, description="Error message, if there is one")
timestamp: str = Field(..., description="Timestamp of the response")
response: str = Field(..., description="Response generated by the agent")
message_history: List[Dict[str, Any]] = Field(
default_factory=list, description="Message history"
)
status: str = Field(..., description="Response status (success/error)")
timestamp: str = Field(..., description="Response timestamp")
class ErrorResponse(BaseModel):
"""Schema for error responses"""
"""Model to represent an error response."""
error: str = Field(..., description="Error message")
status_code: int = Field(..., description="HTTP status code of the error")
details: Optional[Dict[str, Any]] = Field(
None, description="Additional error details"
)
detail: str = Field(..., description="Error details")

View File

@ -33,8 +33,11 @@ from collections.abc import AsyncIterable
from typing import Dict, Optional
from uuid import UUID
import json
import base64
import uuid as uuid_pkg
from sqlalchemy.orm import Session
from google.genai.types import Part, Blob
from src.config.settings import settings
from src.services.agent_service import (
@ -76,6 +79,7 @@ from src.schemas.a2a_types import (
AgentAuthentication,
AgentProvider,
)
from src.schemas.chat import FileData
logger = logging.getLogger(__name__)
@ -281,12 +285,29 @@ class A2ATaskManager:
all_messages.append(agent_message)
task_state = self._determine_task_state(result)
artifact = Artifact(parts=agent_message.parts, index=0)
# Create artifacts for any file content
artifacts = []
# First, add the main response as an artifact
artifacts.append(Artifact(parts=agent_message.parts, index=0))
# Also add any files from the message history
for idx, msg in enumerate(all_messages, 1):
for part in msg.parts:
if hasattr(part, "type") and part.type == "file":
artifacts.append(
Artifact(
parts=[part],
index=idx,
name=part.file.name,
description=f"File from message {idx}",
)
)
task = await self.update_store(
task_params.id,
TaskStatus(state=task_state, message=agent_message),
[artifact],
artifacts,
)
await self._update_task_history(
@ -400,6 +421,32 @@ class A2ATaskManager:
final_message = None
# Check for files in the user message and include them as artifacts
user_files = []
for part in request.params.message.parts:
if (
hasattr(part, "type")
and part.type == "file"
and hasattr(part, "file")
):
user_files.append(
Artifact(
parts=[part],
index=0,
name=part.file.name if part.file.name else "file",
description="File from user",
)
)
# Send artifacts for any user files
for artifact in user_files:
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(
id=request.params.id, artifact=artifact
),
)
async for chunk in run_agent_stream(
agent_id=str(agent.id),
external_id=external_id,
@ -418,7 +465,48 @@ class A2ATaskManager:
parts = content.get("parts", [])
if parts:
update_message = Message(role=role, parts=parts)
# Modify to handle file parts as well
agent_parts = []
for part in parts:
# Handle different part types
if part.get("type") == "text":
agent_parts.append(part)
full_response += part.get("text", "")
elif part.get("inlineData") and part["inlineData"].get(
"data"
):
# Convert inline data to file part
mime_type = part["inlineData"].get(
"mimeType", "application/octet-stream"
)
file_name = f"file_{uuid_pkg.uuid4().hex}{self._get_extension_from_mime(mime_type)}"
file_part = {
"type": "file",
"file": {
"name": file_name,
"mimeType": mime_type,
"bytes": part["inlineData"]["data"],
},
}
agent_parts.append(file_part)
# Also send as artifact
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(
id=request.params.id,
artifact=Artifact(
parts=[file_part],
index=0,
name=file_name,
description=f"Generated {mime_type} file",
),
),
)
if agent_parts:
update_message = Message(role=role, parts=agent_parts)
final_message = update_message
yield SendTaskStreamingResponse(
id=request.id,
@ -431,11 +519,6 @@ class A2ATaskManager:
final=False,
),
)
for part in parts:
if part.get("type") == "text":
full_response += part.get("text", "")
final_message = update_message
except Exception as e:
logger.error(f"Error processing chunk: {e}, chunk: {chunk}")
continue
@ -485,6 +568,29 @@ class A2ATaskManager:
error=InternalError(message=f"Error streaming task process: {str(e)}"),
)
def _get_extension_from_mime(self, mime_type: str) -> str:
"""Get a file extension from MIME type."""
if not mime_type:
return ""
mime_map = {
"image/jpeg": ".jpg",
"image/png": ".png",
"image/gif": ".gif",
"application/pdf": ".pdf",
"text/plain": ".txt",
"text/html": ".html",
"text/csv": ".csv",
"application/json": ".json",
"application/xml": ".xml",
"application/msword": ".doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.ms-excel": ".xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
}
return mime_map.get(mime_type, "")
async def update_store(
self,
task_id: str,
@ -514,19 +620,193 @@ class A2ATaskManager:
return task
def _extract_user_query(self, task_params: TaskSendParams) -> str:
    """Build the text query from the message parts and process attached files.

    Text parts are joined with a single space. File parts are handed to
    ``_process_file_part``; the ones that are processed successfully are
    stored on ``self._last_processed_files`` (``None`` when there are none).
    A file that fails to process is logged and skipped so the remaining
    parts are still handled.

    Args:
        task_params: Task parameters whose ``message.parts`` are inspected.

    Returns:
        The joined text of all text parts, or the fixed prompt
        ``"Analyze the attached files"`` when the message carries only
        file parts.

    Raises:
        ValueError: If the message or its parts are missing, or if the
            message contains neither a text part nor a file part.
    """
    if not task_params.message or not task_params.message.parts:
        raise ValueError("Message or parts are missing in task parameters")

    message_parts = task_params.message.parts
    logger.info(
        f"Extracting query from message with {len(message_parts)} parts"
    )

    text_parts = []
    file_parts = []
    has_files = False

    # Every part is inspected; the message is no longer required to start
    # with a text part.
    for idx, part in enumerate(message_parts):
        logger.info(
            f"Processing part {idx+1}, type: {getattr(part, 'type', 'unknown')}"
        )
        if not hasattr(part, "type"):
            logger.warning(f"Part has no type attribute: {part}")
            continue

        if part.type == "text":
            # A text part without text contributes an empty string instead
            # of failing on the slice here or on the join below.
            text = part.text or ""
            logger.info(f"Found text part: '{text[:50]}...' (truncated)")
            text_parts.append(text)
        elif part.type == "file":
            logger.info(
                f"Found file part: {getattr(getattr(part, 'file', None), 'name', 'unnamed')}"
            )
            # Set even when processing fails, so a file-only message still
            # yields the generic prompt.
            has_files = True
            try:
                processed_file = self._process_file_part(
                    part, task_params.sessionId
                )
                if processed_file:
                    file_parts.append(processed_file)
            except Exception as e:
                # Continue with the other parts even if one file fails.
                logger.error(f"Error processing file part: {e}")
        else:
            logger.warning(f"Unknown part type: {part.type}")

    # Kept on the instance so `_run_agent` can pick the files up afterwards.
    # NOTE(review): this is per-instance state; if one manager instance can
    # serve overlapping requests the value may be overwritten - confirm.
    self._last_processed_files = file_parts if file_parts else None

    if text_parts:
        final_query = " ".join(text_parts)
        logger.info(
            f"Final query from text parts: '{final_query[:50]}...' (truncated)"
        )
        return final_query

    if has_files:
        logger.info("No text parts, using generic query for file analysis")
        return "Analyze the attached files"

    logger.error("No supported content parts found in the message")
    raise ValueError("No supported content parts found in the message")
def _process_file_part(self, part, session_id: str):
    """Decode a file part, store it as an artifact and describe it.

    A part with inline base64 ``bytes`` is decoded, given a MIME type
    (declared type first, then byte signature, then file extension) and
    saved through ``artifacts_service``. A part that only carries a ``uri``
    is logged as unsupported.

    Args:
        part: Message part expected to expose a ``file`` attribute. If the
            file has no name, a generated one is assigned to it.
        session_id: Session identifier; when it has the form ``"a_b"`` the
            two halves are used as user id and app name, otherwise the
            whole value is the user id and the app name is ``"a2a"``.

    Returns:
        A ``FileData`` holding the name, MIME type and original base64
        content, or ``None`` when the part has no usable file content.

    Raises:
        Exception: Any error raised while decoding or saving is logged and
            re-raised unchanged.
    """
    file_data = getattr(part, "file", None)
    if not file_data:
        logger.warning("File part missing file data")
        return None

    if not file_data.name:
        file_data.name = f"file_{uuid_pkg.uuid4().hex}"

    logger.info(f"Processing file {file_data.name} for session {session_id}")

    if not file_data.bytes:
        if file_data.uri:
            # Fetching remote content is not implemented; only report it.
            logger.warning(
                f"File URI references not yet implemented: {file_data.uri}"
            )
        return None

    # Leading byte signatures, checked in order.
    magic_prefixes = (
        (b"\xff\xd8\xff", "image/jpeg"),
        (b"\x89PNG\r\n\x1a\n", "image/png"),
        (b"GIF87a", "image/gif"),
        (b"GIF89a", "image/gif"),
        (b"%PDF", "application/pdf"),
    )
    # File-name suffixes used when no signature matches.
    suffix_types = (
        ((".jpg", ".jpeg"), "image/jpeg"),
        (".png", "image/png"),
        (".gif", "image/gif"),
        (".pdf", "application/pdf"),
    )

    try:
        logger.info(f"Decoding base64 content for file {file_data.name}")
        file_bytes = base64.b64decode(file_data.bytes)

        mime_type = getattr(file_data, "mimeType", None)
        if not mime_type or mime_type == "application/octet-stream":
            for prefix, detected in magic_prefixes:
                if file_bytes.startswith(prefix):
                    mime_type = detected
                    break
            else:
                lowered_name = file_data.name.lower()
                for suffix, detected in suffix_types:
                    if lowered_name.endswith(suffix):
                        mime_type = detected
                        break
                else:
                    mime_type = "application/octet-stream"

        logger.info(
            f"Decoded file size: {len(file_bytes)} bytes, MIME type: {mime_type}"
        )

        # Derive the artifact namespace from the session id.
        id_pieces = session_id.split("_")
        if len(id_pieces) == 2:
            user_id, app_name = id_pieces
        else:
            user_id = session_id
            app_name = "a2a"

        logger.info(f"Creating artifact Part for file {file_data.name}")
        blob = Blob(mime_type=mime_type, data=file_bytes)
        artifact = Part(inline_data=blob)

        logger.info(
            f"Saving artifact {file_data.name} to {app_name}/{user_id}/{session_id}"
        )
        version = artifacts_service.save_artifact(
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=file_data.name,
            artifact=artifact,
        )
        logger.info(
            f"Successfully saved file {file_data.name} (version {version}) for session {session_id}"
        )

        from src.schemas.chat import FileData

        # The base64 text is passed on unchanged rather than the decoded
        # bytes.
        return FileData(
            filename=file_data.name,
            content_type=mime_type,
            data=file_data.bytes,
        )
    except Exception as e:
        logger.error(f"Error processing file data: {str(e)}")
        import traceback

        logger.error(f"Full traceback: {traceback.format_exc()}")
        raise
async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict:
"""Executes the agent to process the user query."""
try:
files = getattr(self, "_last_processed_files", None)
if files:
logger.info(f"Passing {len(files)} files to run_agent")
for file_info in files:
logger.info(
f"File being sent: {file_info.filename} ({file_info.content_type})"
)
else:
logger.info("No files to pass to run_agent")
# We call the same function used in the chat API
return await run_agent(
agent_id=str(agent.id),
@ -536,6 +816,7 @@ class A2ATaskManager:
artifacts_service=artifacts_service,
memory_service=memory_service,
db=self.db,
files=files,
)
except Exception as e:
logger.error(f"Error running agent: {e}")

View File

@ -28,7 +28,7 @@
"""
from google.adk.runners import Runner
from google.genai.types import Content, Part
from google.genai.types import Content, Part, Blob
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
@ -42,6 +42,7 @@ import asyncio
import json
from src.utils.otel import get_tracer
from opentelemetry import trace
import base64
logger = setup_logger(__name__)
@ -56,6 +57,7 @@ async def run_agent(
db: Session,
session_id: Optional[str] = None,
timeout: float = 60.0,
files: Optional[list] = None,
):
tracer = get_tracer()
with tracer.start_as_current_span(
@ -65,6 +67,7 @@ async def run_agent(
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
):
exit_stack = None
@ -74,6 +77,9 @@ async def run_agent(
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@ -113,7 +119,63 @@ async def run_agent(
session_id=adk_session_id,
)
content = Content(role="user", parts=[Part(text=message)])
file_parts = []
if files and len(files) > 0:
for file_data in files:
try:
file_bytes = base64.b64decode(file_data.data)
logger.info(f"DEBUG - Processing file: {file_data.filename}")
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
logger.info(f"DEBUG - MIME type: '{file_data.content_type}'")
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
try:
file_part = Part(
inline_data=Blob(
mime_type=file_data.content_type, data=file_bytes
)
)
logger.info(f"DEBUG - Part created successfully")
except Exception as part_error:
logger.error(
f"DEBUG - Error creating Part: {str(part_error)}"
)
logger.error(
f"DEBUG - Error type: {type(part_error).__name__}"
)
import traceback
logger.error(
f"DEBUG - Stack trace: {traceback.format_exc()}"
)
raise
# Save the file in the ArtifactService
version = artifacts_service.save_artifact(
app_name=agent_id,
user_id=external_id,
session_id=adk_session_id,
filename=file_data.filename,
artifact=file_part,
)
logger.info(
f"Saved file {file_data.filename} as version {version}"
)
# Add the Part to the list of parts for the message content
file_parts.append(file_part)
except Exception as e:
logger.error(
f"Error processing file {file_data.filename}: {str(e)}"
)
# Create the content with the text message and the files
parts = [Part(text=message)]
if file_parts:
parts.extend(file_parts)
content = Content(role="user", parts=parts)
logger.info("Starting agent execution")
final_response_text = "No final response captured."
@ -256,6 +318,7 @@ async def run_agent_stream(
memory_service: InMemoryMemoryService,
db: Session,
session_id: Optional[str] = None,
files: Optional[list] = None,
) -> AsyncGenerator[str, None]:
tracer = get_tracer()
span = tracer.start_span(
@ -265,6 +328,7 @@ async def run_agent_stream(
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
)
try:
@ -275,6 +339,9 @@ async def run_agent_stream(
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@ -314,7 +381,72 @@ async def run_agent_stream(
session_id=adk_session_id,
)
content = Content(role="user", parts=[Part(text=message)])
# Process the received files
file_parts = []
if files and len(files) > 0:
for file_data in files:
try:
# Decode the base64 file
file_bytes = base64.b64decode(file_data.data)
# Detailed debug
logger.info(
f"DEBUG - Processing file: {file_data.filename}"
)
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
logger.info(
f"DEBUG - MIME type: '{file_data.content_type}'"
)
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
# Create a Part for the file using the default constructor
try:
file_part = Part(
inline_data=Blob(
mime_type=file_data.content_type,
data=file_bytes,
)
)
logger.info(f"DEBUG - Part created successfully")
except Exception as part_error:
logger.error(
f"DEBUG - Error creating Part: {str(part_error)}"
)
logger.error(
f"DEBUG - Error type: {type(part_error).__name__}"
)
import traceback
logger.error(
f"DEBUG - Stack trace: {traceback.format_exc()}"
)
raise
# Save the file in the ArtifactService
version = artifacts_service.save_artifact(
app_name=agent_id,
user_id=external_id,
session_id=adk_session_id,
filename=file_data.filename,
artifact=file_part,
)
logger.info(
f"Saved file {file_data.filename} as version {version}"
)
# Add the Part to the list of parts for the message content
file_parts.append(file_part)
except Exception as e:
logger.error(
f"Error processing file {file_data.filename}: {str(e)}"
)
# Create the content with the text message and the files
parts = [Part(text=message)]
if file_parts:
parts.extend(file_parts)
content = Content(role="user", parts=parts)
logger.info("Starting agent streaming execution")
try:

View File

@ -27,10 +27,16 @@
"""
import base64
import uuid
from typing import Dict, List, Any, Optional
from google.genai.types import Part, Blob
from src.schemas.a2a_types import (
ContentTypeNotSupportedError,
JSONRPCResponse,
UnsupportedOperationError,
Message,
)
@ -55,3 +61,130 @@ def new_incompatible_types_error(request_id):
def new_not_implemented_error(request_id):
return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())
def extract_files_from_message(message: Message) -> List[Dict[str, Any]]:
    """Collect the file parts of an A2A message.

    A part counts as a file part when its ``type`` is ``"file"`` and it has
    a ``file`` attribute. The matching part objects are returned as-is, in
    message order.

    Args:
        message: An A2A Message object; may be ``None``.

    Returns:
        The file parts found, or an empty list when the message is missing
        or has no parts.
    """
    if not message or not message.parts:
        return []

    return [
        candidate
        for candidate in message.parts
        if getattr(candidate, "type", None) == "file"
        and hasattr(candidate, "file")
    ]
def a2a_part_to_adk_part(a2a_part: Dict[str, Any]) -> Optional[Part]:
    """Turn an A2A file part into an ADK ``Part`` with inline data.

    Only parts of type ``"file"`` that carry base64 ``bytes`` are
    converted. Text parts and every other input produce ``None``.

    Args:
        a2a_part: An A2A part dictionary.

    Returns:
        An ADK ``Part`` wrapping the decoded bytes, or ``None`` when the
        part is not a convertible file part or the conversion fails.
    """
    if a2a_part.get("type") != "file" or "file" not in a2a_part:
        return None

    file_info = a2a_part["file"]
    if "bytes" not in file_info:
        return None

    try:
        decoded = base64.b64decode(file_info["bytes"])
        declared_type = file_info.get("mimeType", "application/octet-stream")
        blob = Blob(mime_type=declared_type, data=decoded)
        return Part(inline_data=blob)
    except Exception:
        # Any failure while decoding or building the Part means "not
        # convertible".
        return None
def adk_part_to_a2a_part(
    adk_part: Part, filename: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Turn an ADK ``Part`` into an A2A part dictionary.

    A part with inline data becomes a ``"file"`` part whose content is
    base64 text. A part without inline data but with text becomes a
    ``"text"`` part.

    Args:
        adk_part: An ADK Part object.
        filename: Name to use for a file part. When omitted, a name is
            generated from a random hex id plus the extension returned by
            ``get_extension_from_mime``.

    Returns:
        The A2A part dictionary, or ``None`` when the part has inline data
        that lacks content or a MIME type, or has neither inline data nor
        text.
    """
    blob = getattr(adk_part, "inline_data", None)
    if blob:
        # Inline data takes precedence; an incomplete blob is not converted
        # and the text is not consulted in that case.
        if not (blob.data and blob.mime_type):
            return None

        payload = blob.data
        declared_type = blob.mime_type

        if not filename:
            suffix = get_extension_from_mime(declared_type)
            filename = f"file_{uuid.uuid4().hex}{suffix}"

        if isinstance(payload, bytes):
            encoded = base64.b64encode(payload).decode("utf-8")
        else:
            encoded = str(payload)

        return {
            "type": "file",
            "file": {
                "name": filename,
                "mimeType": declared_type,
                "bytes": encoded,
            },
        }

    text = getattr(adk_part, "text", None)
    if text:
        return {"type": "text", "text": text}

    return None
# Extensions for the MIME types this module knows about, keyed by the
# lower-case "type/subtype" token.
_MIME_EXTENSIONS = {
    "image/jpeg": ".jpg",
    "image/png": ".png",
    "image/gif": ".gif",
    "application/pdf": ".pdf",
    "text/plain": ".txt",
    "text/html": ".html",
    "text/csv": ".csv",
    "application/json": ".json",
    "application/xml": ".xml",
    "application/msword": ".doc",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "application/vnd.ms-excel": ".xls",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
}


def get_extension_from_mime(mime_type: str) -> str:
    """Get a file extension from a MIME type.

    The lookup ignores case, surrounding whitespace and any parameters
    after ``;`` (for example ``"text/plain; charset=utf-8"``), because none
    of those change which extension applies.

    Args:
        mime_type: MIME type string; may be empty or ``None``.

    Returns:
        The file extension with a leading dot, or ``""`` when the type is
        missing, not a string, or not present in the table.
    """
    if not mime_type or not isinstance(mime_type, str):
        return ""

    # Keep only the "type/subtype" token and normalise its case before the
    # table lookup.
    essence = mime_type.split(";", 1)[0].strip().lower()
    return _MIME_EXTENSIONS.get(essence, "")