feat(a2a): add file support and multimodal content processing for A2A protocol

This commit is contained in:
Davidson Gomes
2025-05-14 22:15:08 -03:00
parent 958eeec4a6
commit 6bf0ea52e0
8 changed files with 869 additions and 43 deletions

View File

@@ -33,8 +33,11 @@ from collections.abc import AsyncIterable
from typing import Dict, Optional
from uuid import UUID
import json
import base64
import uuid as uuid_pkg
from sqlalchemy.orm import Session
from google.genai.types import Part, Blob
from src.config.settings import settings
from src.services.agent_service import (
@@ -76,6 +79,7 @@ from src.schemas.a2a_types import (
AgentAuthentication,
AgentProvider,
)
from src.schemas.chat import FileData
logger = logging.getLogger(__name__)
@@ -281,12 +285,29 @@ class A2ATaskManager:
all_messages.append(agent_message)
task_state = self._determine_task_state(result)
artifact = Artifact(parts=agent_message.parts, index=0)
# Create artifacts for any file content
artifacts = []
# First, add the main response as an artifact
artifacts.append(Artifact(parts=agent_message.parts, index=0))
# Also add any files from the message history
for idx, msg in enumerate(all_messages, 1):
for part in msg.parts:
if hasattr(part, "type") and part.type == "file":
artifacts.append(
Artifact(
parts=[part],
index=idx,
name=part.file.name,
description=f"File from message {idx}",
)
)
task = await self.update_store(
task_params.id,
TaskStatus(state=task_state, message=agent_message),
[artifact],
artifacts,
)
await self._update_task_history(
@@ -400,6 +421,32 @@ class A2ATaskManager:
final_message = None
# Check for files in the user message and include them as artifacts
user_files = []
for part in request.params.message.parts:
if (
hasattr(part, "type")
and part.type == "file"
and hasattr(part, "file")
):
user_files.append(
Artifact(
parts=[part],
index=0,
name=part.file.name if part.file.name else "file",
description="File from user",
)
)
# Send artifacts for any user files
for artifact in user_files:
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(
id=request.params.id, artifact=artifact
),
)
async for chunk in run_agent_stream(
agent_id=str(agent.id),
external_id=external_id,
@@ -418,7 +465,48 @@ class A2ATaskManager:
parts = content.get("parts", [])
if parts:
update_message = Message(role=role, parts=parts)
# Modify to handle file parts as well
agent_parts = []
for part in parts:
# Handle different part types
if part.get("type") == "text":
agent_parts.append(part)
full_response += part.get("text", "")
elif part.get("inlineData") and part["inlineData"].get(
"data"
):
# Convert inline data to file part
mime_type = part["inlineData"].get(
"mimeType", "application/octet-stream"
)
file_name = f"file_{uuid_pkg.uuid4().hex}{self._get_extension_from_mime(mime_type)}"
file_part = {
"type": "file",
"file": {
"name": file_name,
"mimeType": mime_type,
"bytes": part["inlineData"]["data"],
},
}
agent_parts.append(file_part)
# Also send as artifact
yield SendTaskStreamingResponse(
id=request.id,
result=TaskArtifactUpdateEvent(
id=request.params.id,
artifact=Artifact(
parts=[file_part],
index=0,
name=file_name,
description=f"Generated {mime_type} file",
),
),
)
if agent_parts:
update_message = Message(role=role, parts=agent_parts)
final_message = update_message
yield SendTaskStreamingResponse(
id=request.id,
@@ -431,11 +519,6 @@ class A2ATaskManager:
final=False,
),
)
for part in parts:
if part.get("type") == "text":
full_response += part.get("text", "")
final_message = update_message
except Exception as e:
logger.error(f"Error processing chunk: {e}, chunk: {chunk}")
continue
@@ -485,6 +568,29 @@ class A2ATaskManager:
error=InternalError(message=f"Error streaming task process: {str(e)}"),
)
def _get_extension_from_mime(self, mime_type: str) -> str:
"""Get a file extension from MIME type."""
if not mime_type:
return ""
mime_map = {
"image/jpeg": ".jpg",
"image/png": ".png",
"image/gif": ".gif",
"application/pdf": ".pdf",
"text/plain": ".txt",
"text/html": ".html",
"text/csv": ".csv",
"application/json": ".json",
"application/xml": ".xml",
"application/msword": ".doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.ms-excel": ".xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
}
return mime_map.get(mime_type, "")
async def update_store(
self,
task_id: str,
@@ -514,19 +620,193 @@ class A2ATaskManager:
return task
def _extract_user_query(self, task_params: TaskSendParams) -> str:
    """Extracts the user query from the task parameters and processes any files.

    Text parts are concatenated into a single query string. File parts are
    handed to ``self._process_file_part`` and the successfully processed
    results are stashed on ``self._last_processed_files`` so that
    ``_run_agent`` can forward them to the agent runner.

    Args:
        task_params: Incoming A2A task parameters carrying the user message.

    Returns:
        The combined query text, or a generic file-analysis prompt when the
        message contains only file parts.

    Raises:
        ValueError: If the message is missing, empty, or contains no
            supported (text or file) parts.
    """
    if not task_params.message or not task_params.message.parts:
        raise ValueError("Message or parts are missing in task parameters")

    # Collect text and file parts separately; a file failure must not
    # abort the whole request.
    text_parts = []
    has_files = False
    file_parts = []

    logger.info(
        f"Extracting query from message with {len(task_params.message.parts)} parts"
    )

    for idx, part in enumerate(task_params.message.parts):
        logger.info(
            f"Processing part {idx+1}, type: {getattr(part, 'type', 'unknown')}"
        )
        if hasattr(part, "type"):
            if part.type == "text":
                logger.info(f"Found text part: '{part.text[:50]}...' (truncated)")
                text_parts.append(part.text)
            elif part.type == "file":
                logger.info(
                    f"Found file part: {getattr(getattr(part, 'file', None), 'name', 'unnamed')}"
                )
                has_files = True
                try:
                    processed_file = self._process_file_part(
                        part, task_params.sessionId
                    )
                    if processed_file:
                        file_parts.append(processed_file)
                except Exception as e:
                    # Continue with other parts even if a file fails
                    logger.error(f"Error processing file part: {e}")
            else:
                logger.warning(f"Unknown part type: {part.type}")
        else:
            logger.warning(f"Part has no type attribute: {part}")

    # Store the file parts on the instance for later use by _run_agent
    self._last_processed_files = file_parts if file_parts else None

    # If we have at least one text part, join them into the query
    if text_parts:
        final_query = " ".join(text_parts)
        logger.info(
            f"Final query from text parts: '{final_query[:50]}...' (truncated)"
        )
        return final_query
    # If we only have file parts, ask the agent to analyze the attachments
    elif has_files:
        logger.info("No text parts, using generic query for file analysis")
        return "Analyze the attached files"
    else:
        logger.error("No supported content parts found in the message")
        raise ValueError("No supported content parts found in the message")
def _process_file_part(self, part, session_id: str):
    """Processes a file part and saves it to the artifact service.

    Decodes the part's base64 payload, determines a MIME type (explicit
    attribute, then byte-signature sniffing, then filename extension),
    persists the bytes via ``artifacts_service``, and returns the file
    info in the shape that agent_runner expects. URI-backed files are
    not yet supported and yield ``None``.

    Args:
        part: Message part expected to expose a ``file`` attribute with
            ``name``, ``bytes`` (base64 string) and optionally
            ``mimeType`` or ``uri``.
        session_id: Session identifier; presumably formatted as
            ``"<user_id>_<app_name>"`` — TODO confirm the unpack order
            against the session-id producer.

    Returns:
        FileData | None: Processed file information to pass to
        agent_runner, or ``None`` when there is nothing to process.

    Raises:
        Exception: Re-raises any failure while decoding or saving the
            base64 payload (after logging the full traceback).
    """
    if not hasattr(part, "file") or not part.file:
        logger.warning("File part missing file data")
        return None

    file_data = part.file

    # Ensure the file has a name before it is stored
    if not file_data.name:
        file_data.name = f"file_{uuid_pkg.uuid4().hex}"

    logger.info(f"Processing file {file_data.name} for session {session_id}")

    if file_data.bytes:
        # Process file data provided as base64 string
        try:
            # Convert base64 to bytes
            logger.info(f"Decoding base64 content for file {file_data.name}")
            file_bytes = base64.b64decode(file_data.bytes)

            # Determine MIME type based on binary content
            mime_type = (
                file_data.mimeType if hasattr(file_data, "mimeType") else None
            )
            # Missing or generic MIME types are refined below; a present
            # but empty mimeType also falls into this branch.
            if not mime_type or mime_type == "application/octet-stream":
                # Detection by byte signature
                if file_bytes.startswith(b"\xff\xd8\xff"):  # JPEG signature
                    mime_type = "image/jpeg"
                elif file_bytes.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG signature
                    mime_type = "image/png"
                elif file_bytes.startswith(b"GIF87a") or file_bytes.startswith(
                    b"GIF89a"
                ):  # GIF
                    mime_type = "image/gif"
                elif file_bytes.startswith(b"%PDF"):  # PDF
                    mime_type = "application/pdf"
                else:
                    # Fallback to avoid generic type in images:
                    # guess from the filename extension instead.
                    if file_data.name.lower().endswith((".jpg", ".jpeg")):
                        mime_type = "image/jpeg"
                    elif file_data.name.lower().endswith(".png"):
                        mime_type = "image/png"
                    elif file_data.name.lower().endswith(".gif"):
                        mime_type = "image/gif"
                    elif file_data.name.lower().endswith(".pdf"):
                        mime_type = "application/pdf"
                    else:
                        mime_type = "application/octet-stream"

            logger.info(
                f"Decoded file size: {len(file_bytes)} bytes, MIME type: {mime_type}"
            )

            # Split session_id to get app_name and user_id.
            # NOTE(review): assumes "<user_id>_<app_name>" ordering; ids
            # containing their own underscores would fail the len check
            # and fall back to the "a2a" app bucket.
            parts = session_id.split("_")
            if len(parts) != 2:
                user_id = session_id
                app_name = "a2a"
            else:
                user_id, app_name = parts

            # Create artifact Part
            logger.info(f"Creating artifact Part for file {file_data.name}")
            artifact = Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))

            # Save to artifact service
            logger.info(
                f"Saving artifact {file_data.name} to {app_name}/{user_id}/{session_id}"
            )
            version = artifacts_service.save_artifact(
                app_name=app_name,
                user_id=user_id,
                session_id=session_id,
                filename=file_data.name,
                artifact=artifact,
            )
            logger.info(
                f"Successfully saved file {file_data.name} (version {version}) for session {session_id}"
            )

            # Import the FileData model from the chat schema
            from src.schemas.chat import FileData

            # Create a FileData object instead of a dictionary
            # This is compatible with what agent_runner.py expects
            return FileData(
                filename=file_data.name,
                content_type=mime_type,
                data=file_data.bytes,  # Keep the original base64 format
            )
        except Exception as e:
            logger.error(f"Error processing file data: {str(e)}")
            # Log more details about the error to help with debugging
            import traceback

            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise
    elif file_data.uri:
        # Handling URIs would require additional implementation
        # For now, log that we received a URI but can't process it
        logger.warning(f"File URI references not yet implemented: {file_data.uri}")
        # Future enhancement: fetch the file from the URI and save it
        return None

    return None
async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict:
"""Executes the agent to process the user query."""
try:
files = getattr(self, "_last_processed_files", None)
if files:
logger.info(f"Passing {len(files)} files to run_agent")
for file_info in files:
logger.info(
f"File being sent: {file_info.filename} ({file_info.content_type})"
)
else:
logger.info("No files to pass to run_agent")
# We call the same function used in the chat API
return await run_agent(
agent_id=str(agent.id),
@@ -536,6 +816,7 @@ class A2ATaskManager:
artifacts_service=artifacts_service,
memory_service=memory_service,
db=self.db,
files=files,
)
except Exception as e:
logger.error(f"Error running agent: {e}")

View File

@@ -28,7 +28,7 @@
"""
from google.adk.runners import Runner
from google.genai.types import Content, Part
from google.genai.types import Content, Part, Blob
from google.adk.sessions import DatabaseSessionService
from google.adk.memory import InMemoryMemoryService
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
@@ -42,6 +42,7 @@ import asyncio
import json
from src.utils.otel import get_tracer
from opentelemetry import trace
import base64
logger = setup_logger(__name__)
@@ -56,6 +57,7 @@ async def run_agent(
db: Session,
session_id: Optional[str] = None,
timeout: float = 60.0,
files: Optional[list] = None,
):
tracer = get_tracer()
with tracer.start_as_current_span(
@@ -65,6 +67,7 @@ async def run_agent(
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
):
exit_stack = None
@@ -74,6 +77,9 @@ async def run_agent(
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@@ -113,7 +119,63 @@ async def run_agent(
session_id=adk_session_id,
)
content = Content(role="user", parts=[Part(text=message)])
file_parts = []
if files and len(files) > 0:
for file_data in files:
try:
file_bytes = base64.b64decode(file_data.data)
logger.info(f"DEBUG - Processing file: {file_data.filename}")
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
logger.info(f"DEBUG - MIME type: '{file_data.content_type}'")
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
try:
file_part = Part(
inline_data=Blob(
mime_type=file_data.content_type, data=file_bytes
)
)
logger.info(f"DEBUG - Part created successfully")
except Exception as part_error:
logger.error(
f"DEBUG - Error creating Part: {str(part_error)}"
)
logger.error(
f"DEBUG - Error type: {type(part_error).__name__}"
)
import traceback
logger.error(
f"DEBUG - Stack trace: {traceback.format_exc()}"
)
raise
# Save the file in the ArtifactService
version = artifacts_service.save_artifact(
app_name=agent_id,
user_id=external_id,
session_id=adk_session_id,
filename=file_data.filename,
artifact=file_part,
)
logger.info(
f"Saved file {file_data.filename} as version {version}"
)
# Add the Part to the list of parts for the message content
file_parts.append(file_part)
except Exception as e:
logger.error(
f"Error processing file {file_data.filename}: {str(e)}"
)
# Create the content with the text message and the files
parts = [Part(text=message)]
if file_parts:
parts.extend(file_parts)
content = Content(role="user", parts=parts)
logger.info("Starting agent execution")
final_response_text = "No final response captured."
@@ -256,6 +318,7 @@ async def run_agent_stream(
memory_service: InMemoryMemoryService,
db: Session,
session_id: Optional[str] = None,
files: Optional[list] = None,
) -> AsyncGenerator[str, None]:
tracer = get_tracer()
span = tracer.start_span(
@@ -265,6 +328,7 @@ async def run_agent_stream(
"external_id": external_id,
"session_id": session_id or f"{external_id}_{agent_id}",
"message": message,
"has_files": files is not None and len(files) > 0,
},
)
try:
@@ -275,6 +339,9 @@ async def run_agent_stream(
)
logger.info(f"Received message: {message}")
if files and len(files) > 0:
logger.info(f"Received {len(files)} files with message")
get_root_agent = get_agent(db, agent_id)
logger.info(
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@@ -314,7 +381,72 @@ async def run_agent_stream(
session_id=adk_session_id,
)
content = Content(role="user", parts=[Part(text=message)])
# Process the received files
file_parts = []
if files and len(files) > 0:
for file_data in files:
try:
# Decode the base64 file
file_bytes = base64.b64decode(file_data.data)
# Detailed debug
logger.info(
f"DEBUG - Processing file: {file_data.filename}"
)
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
logger.info(
f"DEBUG - MIME type: '{file_data.content_type}'"
)
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
# Create a Part for the file using the default constructor
try:
file_part = Part(
inline_data=Blob(
mime_type=file_data.content_type,
data=file_bytes,
)
)
logger.info(f"DEBUG - Part created successfully")
except Exception as part_error:
logger.error(
f"DEBUG - Error creating Part: {str(part_error)}"
)
logger.error(
f"DEBUG - Error type: {type(part_error).__name__}"
)
import traceback
logger.error(
f"DEBUG - Stack trace: {traceback.format_exc()}"
)
raise
# Save the file in the ArtifactService
version = artifacts_service.save_artifact(
app_name=agent_id,
user_id=external_id,
session_id=adk_session_id,
filename=file_data.filename,
artifact=file_part,
)
logger.info(
f"Saved file {file_data.filename} as version {version}"
)
# Add the Part to the list of parts for the message content
file_parts.append(file_part)
except Exception as e:
logger.error(
f"Error processing file {file_data.filename}: {str(e)}"
)
# Create the content with the text message and the files
parts = [Part(text=message)]
if file_parts:
parts.extend(file_parts)
content = Content(role="user", parts=parts)
logger.info("Starting agent streaming execution")
try: