feat(a2a): add file support and multimodal content processing for A2A protocol
This commit is contained in:
@@ -33,8 +33,11 @@ from collections.abc import AsyncIterable
|
||||
from typing import Dict, Optional
|
||||
from uuid import UUID
|
||||
import json
|
||||
import base64
|
||||
import uuid as uuid_pkg
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from google.genai.types import Part, Blob
|
||||
|
||||
from src.config.settings import settings
|
||||
from src.services.agent_service import (
|
||||
@@ -76,6 +79,7 @@ from src.schemas.a2a_types import (
|
||||
AgentAuthentication,
|
||||
AgentProvider,
|
||||
)
|
||||
from src.schemas.chat import FileData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -281,12 +285,29 @@ class A2ATaskManager:
|
||||
all_messages.append(agent_message)
|
||||
|
||||
task_state = self._determine_task_state(result)
|
||||
artifact = Artifact(parts=agent_message.parts, index=0)
|
||||
|
||||
# Create artifacts for any file content
|
||||
artifacts = []
|
||||
# First, add the main response as an artifact
|
||||
artifacts.append(Artifact(parts=agent_message.parts, index=0))
|
||||
|
||||
# Also add any files from the message history
|
||||
for idx, msg in enumerate(all_messages, 1):
|
||||
for part in msg.parts:
|
||||
if hasattr(part, "type") and part.type == "file":
|
||||
artifacts.append(
|
||||
Artifact(
|
||||
parts=[part],
|
||||
index=idx,
|
||||
name=part.file.name,
|
||||
description=f"File from message {idx}",
|
||||
)
|
||||
)
|
||||
|
||||
task = await self.update_store(
|
||||
task_params.id,
|
||||
TaskStatus(state=task_state, message=agent_message),
|
||||
[artifact],
|
||||
artifacts,
|
||||
)
|
||||
|
||||
await self._update_task_history(
|
||||
@@ -400,6 +421,32 @@ class A2ATaskManager:
|
||||
|
||||
final_message = None
|
||||
|
||||
# Check for files in the user message and include them as artifacts
|
||||
user_files = []
|
||||
for part in request.params.message.parts:
|
||||
if (
|
||||
hasattr(part, "type")
|
||||
and part.type == "file"
|
||||
and hasattr(part, "file")
|
||||
):
|
||||
user_files.append(
|
||||
Artifact(
|
||||
parts=[part],
|
||||
index=0,
|
||||
name=part.file.name if part.file.name else "file",
|
||||
description="File from user",
|
||||
)
|
||||
)
|
||||
|
||||
# Send artifacts for any user files
|
||||
for artifact in user_files:
|
||||
yield SendTaskStreamingResponse(
|
||||
id=request.id,
|
||||
result=TaskArtifactUpdateEvent(
|
||||
id=request.params.id, artifact=artifact
|
||||
),
|
||||
)
|
||||
|
||||
async for chunk in run_agent_stream(
|
||||
agent_id=str(agent.id),
|
||||
external_id=external_id,
|
||||
@@ -418,7 +465,48 @@ class A2ATaskManager:
|
||||
parts = content.get("parts", [])
|
||||
|
||||
if parts:
|
||||
update_message = Message(role=role, parts=parts)
|
||||
# Modify to handle file parts as well
|
||||
agent_parts = []
|
||||
for part in parts:
|
||||
# Handle different part types
|
||||
if part.get("type") == "text":
|
||||
agent_parts.append(part)
|
||||
full_response += part.get("text", "")
|
||||
elif part.get("inlineData") and part["inlineData"].get(
|
||||
"data"
|
||||
):
|
||||
# Convert inline data to file part
|
||||
mime_type = part["inlineData"].get(
|
||||
"mimeType", "application/octet-stream"
|
||||
)
|
||||
file_name = f"file_{uuid_pkg.uuid4().hex}{self._get_extension_from_mime(mime_type)}"
|
||||
file_part = {
|
||||
"type": "file",
|
||||
"file": {
|
||||
"name": file_name,
|
||||
"mimeType": mime_type,
|
||||
"bytes": part["inlineData"]["data"],
|
||||
},
|
||||
}
|
||||
agent_parts.append(file_part)
|
||||
|
||||
# Also send as artifact
|
||||
yield SendTaskStreamingResponse(
|
||||
id=request.id,
|
||||
result=TaskArtifactUpdateEvent(
|
||||
id=request.params.id,
|
||||
artifact=Artifact(
|
||||
parts=[file_part],
|
||||
index=0,
|
||||
name=file_name,
|
||||
description=f"Generated {mime_type} file",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if agent_parts:
|
||||
update_message = Message(role=role, parts=agent_parts)
|
||||
final_message = update_message
|
||||
|
||||
yield SendTaskStreamingResponse(
|
||||
id=request.id,
|
||||
@@ -431,11 +519,6 @@ class A2ATaskManager:
|
||||
final=False,
|
||||
),
|
||||
)
|
||||
|
||||
for part in parts:
|
||||
if part.get("type") == "text":
|
||||
full_response += part.get("text", "")
|
||||
final_message = update_message
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing chunk: {e}, chunk: {chunk}")
|
||||
continue
|
||||
@@ -485,6 +568,29 @@ class A2ATaskManager:
|
||||
error=InternalError(message=f"Error streaming task process: {str(e)}"),
|
||||
)
|
||||
|
||||
def _get_extension_from_mime(self, mime_type: str) -> str:
|
||||
"""Get a file extension from MIME type."""
|
||||
if not mime_type:
|
||||
return ""
|
||||
|
||||
mime_map = {
|
||||
"image/jpeg": ".jpg",
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"application/pdf": ".pdf",
|
||||
"text/plain": ".txt",
|
||||
"text/html": ".html",
|
||||
"text/csv": ".csv",
|
||||
"application/json": ".json",
|
||||
"application/xml": ".xml",
|
||||
"application/msword": ".doc",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
}
|
||||
|
||||
return mime_map.get(mime_type, "")
|
||||
|
||||
async def update_store(
|
||||
self,
|
||||
task_id: str,
|
||||
@@ -514,19 +620,193 @@ class A2ATaskManager:
|
||||
return task
|
||||
|
||||
def _extract_user_query(self, task_params: TaskSendParams) -> str:
    """Extract the user query text from the task message, processing any file parts.

    Text parts are joined into the query string; file parts are run through
    ``_process_file_part`` and stashed on ``self._last_processed_files`` for
    the subsequent agent run. A file-only message yields a generic
    "analyze the attached files" query.

    Raises:
        ValueError: If the message/parts are missing, or no supported part
            type is present.
    """
    if not task_params.message or not task_params.message.parts:
        raise ValueError("Message or parts are missing in task parameters")

    collected_texts = []
    saw_file_part = False
    processed_files = []

    logger.info(
        f"Extracting query from message with {len(task_params.message.parts)} parts"
    )

    for idx, part in enumerate(task_params.message.parts):
        logger.info(
            f"Processing part {idx+1}, type: {getattr(part, 'type', 'unknown')}"
        )
        # Guard clause: parts without a type attribute are skipped with a warning.
        if not hasattr(part, "type"):
            logger.warning(f"Part has no type attribute: {part}")
            continue

        if part.type == "text":
            logger.info(f"Found text part: '{part.text[:50]}...' (truncated)")
            collected_texts.append(part.text)
        elif part.type == "file":
            logger.info(
                f"Found file part: {getattr(getattr(part, 'file', None), 'name', 'unnamed')}"
            )
            saw_file_part = True
            try:
                result = self._process_file_part(part, task_params.sessionId)
                if result:
                    processed_files.append(result)
            except Exception as e:
                logger.error(f"Error processing file part: {e}")
                # A single failed file must not abort the whole message.
        else:
            logger.warning(f"Unknown part type: {part.type}")

    # Stash processed files on the instance so _run_agent can forward them.
    self._last_processed_files = processed_files if processed_files else None

    # Prefer explicit text as the query.
    if collected_texts:
        final_query = " ".join(collected_texts)
        logger.info(
            f"Final query from text parts: '{final_query[:50]}...' (truncated)"
        )
        return final_query

    # File-only message: fall back to a generic analysis request.
    if saw_file_part:
        logger.info("No text parts, using generic query for file analysis")
        return "Analyze the attached files"

    logger.error("No supported content parts found in the message")
    raise ValueError("No supported content parts found in the message")
|
||||
|
||||
def _process_file_part(self, part, session_id: str):
    """Process a message file part and persist it to the artifact service.

    Decodes base64 file content, sniffs the MIME type when missing or
    generic (magic bytes first, then filename suffix), saves the bytes as
    an artifact under the app/user derived from ``session_id``, and returns
    a ``FileData`` describing the file for ``agent_runner``.

    Args:
        part: Message part expected to carry a ``file`` payload.
        session_id: Session identifier; when of the form
            ``"<user_id>_<app_name>"`` it is split, otherwise the whole id
            is treated as the user id with app ``"a2a"``.

    Returns:
        FileData | None: The processed file info, or ``None`` when there is
        nothing processable (no payload, or a URI reference).

    Raises:
        Exception: Re-raises any failure while decoding/saving the bytes.
    """
    if not hasattr(part, "file") or not part.file:
        logger.warning("File part missing file data")
        return None

    file_data = part.file

    # Ensure every file has a name so it can be stored and referenced.
    if not file_data.name:
        file_data.name = f"file_{uuid_pkg.uuid4().hex}"

    logger.info(f"Processing file {file_data.name} for session {session_id}")

    if file_data.bytes:
        # Inline base64 payload.
        try:
            logger.info(f"Decoding base64 content for file {file_data.name}")
            file_bytes = base64.b64decode(file_data.bytes)

            mime_type = (
                file_data.mimeType if hasattr(file_data, "mimeType") else None
            )

            # Sniff a concrete MIME type when none was given (or only the
            # generic octet-stream): magic-byte signatures first, then the
            # filename suffix as a fallback.
            if not mime_type or mime_type == "application/octet-stream":
                signature_table = (
                    (b"\xff\xd8\xff", "image/jpeg"),
                    (b"\x89PNG\r\n\x1a\n", "image/png"),
                    (b"GIF87a", "image/gif"),
                    (b"GIF89a", "image/gif"),
                    (b"%PDF", "application/pdf"),
                )
                for magic, candidate in signature_table:
                    if file_bytes.startswith(magic):
                        mime_type = candidate
                        break
                else:
                    suffix_table = (
                        ((".jpg", ".jpeg"), "image/jpeg"),
                        ((".png",), "image/png"),
                        ((".gif",), "image/gif"),
                        ((".pdf",), "application/pdf"),
                    )
                    lowered_name = file_data.name.lower()
                    for suffixes, candidate in suffix_table:
                        if lowered_name.endswith(suffixes):
                            mime_type = candidate
                            break
                    else:
                        mime_type = "application/octet-stream"

            logger.info(
                f"Decoded file size: {len(file_bytes)} bytes, MIME type: {mime_type}"
            )

            # Derive app_name/user_id from the session id.
            # NOTE(review): assumes "<user_id>_<app_name>" ordering — confirm
            # against how session ids are composed by the caller.
            segments = session_id.split("_")
            if len(segments) == 2:
                user_id, app_name = segments
            else:
                user_id = session_id
                app_name = "a2a"

            logger.info(f"Creating artifact Part for file {file_data.name}")
            artifact = Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))

            logger.info(
                f"Saving artifact {file_data.name} to {app_name}/{user_id}/{session_id}"
            )
            version = artifacts_service.save_artifact(
                app_name=app_name,
                user_id=user_id,
                session_id=session_id,
                filename=file_data.name,
                artifact=artifact,
            )

            logger.info(
                f"Successfully saved file {file_data.name} (version {version}) for session {session_id}"
            )

            # Local import mirrors the shape agent_runner.py expects.
            from src.schemas.chat import FileData

            # Return the ORIGINAL base64 payload (not the decoded bytes) —
            # downstream consumers decode it themselves.
            return FileData(
                filename=file_data.name,
                content_type=mime_type,
                data=file_data.bytes,
            )

        except Exception as e:
            logger.error(f"Error processing file data: {str(e)}")
            # Full traceback aids debugging of malformed payloads.
            import traceback

            logger.error(f"Full traceback: {traceback.format_exc()}")
            raise

    elif file_data.uri:
        # URI-referenced files are not fetched yet; log and skip.
        logger.warning(f"File URI references not yet implemented: {file_data.uri}")
        # Future enhancement: fetch the file from the URI and save it
        return None

    return None
|
||||
|
||||
async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict:
|
||||
"""Executes the agent to process the user query."""
|
||||
try:
|
||||
files = getattr(self, "_last_processed_files", None)
|
||||
|
||||
if files:
|
||||
logger.info(f"Passing {len(files)} files to run_agent")
|
||||
for file_info in files:
|
||||
logger.info(
|
||||
f"File being sent: {file_info.filename} ({file_info.content_type})"
|
||||
)
|
||||
else:
|
||||
logger.info("No files to pass to run_agent")
|
||||
|
||||
# We call the same function used in the chat API
|
||||
return await run_agent(
|
||||
agent_id=str(agent.id),
|
||||
@@ -536,6 +816,7 @@ class A2ATaskManager:
|
||||
artifacts_service=artifacts_service,
|
||||
memory_service=memory_service,
|
||||
db=self.db,
|
||||
files=files,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running agent: {e}")
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
"""
|
||||
|
||||
from google.adk.runners import Runner
|
||||
from google.genai.types import Content, Part
|
||||
from google.genai.types import Content, Part, Blob
|
||||
from google.adk.sessions import DatabaseSessionService
|
||||
from google.adk.memory import InMemoryMemoryService
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
@@ -42,6 +42,7 @@ import asyncio
|
||||
import json
|
||||
from src.utils.otel import get_tracer
|
||||
from opentelemetry import trace
|
||||
import base64
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@@ -56,6 +57,7 @@ async def run_agent(
|
||||
db: Session,
|
||||
session_id: Optional[str] = None,
|
||||
timeout: float = 60.0,
|
||||
files: Optional[list] = None,
|
||||
):
|
||||
tracer = get_tracer()
|
||||
with tracer.start_as_current_span(
|
||||
@@ -65,6 +67,7 @@ async def run_agent(
|
||||
"external_id": external_id,
|
||||
"session_id": session_id or f"{external_id}_{agent_id}",
|
||||
"message": message,
|
||||
"has_files": files is not None and len(files) > 0,
|
||||
},
|
||||
):
|
||||
exit_stack = None
|
||||
@@ -74,6 +77,9 @@ async def run_agent(
|
||||
)
|
||||
logger.info(f"Received message: {message}")
|
||||
|
||||
if files and len(files) > 0:
|
||||
logger.info(f"Received {len(files)} files with message")
|
||||
|
||||
get_root_agent = get_agent(db, agent_id)
|
||||
logger.info(
|
||||
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
|
||||
@@ -113,7 +119,63 @@ async def run_agent(
|
||||
session_id=adk_session_id,
|
||||
)
|
||||
|
||||
content = Content(role="user", parts=[Part(text=message)])
|
||||
file_parts = []
|
||||
if files and len(files) > 0:
|
||||
for file_data in files:
|
||||
try:
|
||||
file_bytes = base64.b64decode(file_data.data)
|
||||
|
||||
logger.info(f"DEBUG - Processing file: {file_data.filename}")
|
||||
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
|
||||
logger.info(f"DEBUG - MIME type: '{file_data.content_type}'")
|
||||
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
|
||||
|
||||
try:
|
||||
file_part = Part(
|
||||
inline_data=Blob(
|
||||
mime_type=file_data.content_type, data=file_bytes
|
||||
)
|
||||
)
|
||||
logger.info(f"DEBUG - Part created successfully")
|
||||
except Exception as part_error:
|
||||
logger.error(
|
||||
f"DEBUG - Error creating Part: {str(part_error)}"
|
||||
)
|
||||
logger.error(
|
||||
f"DEBUG - Error type: {type(part_error).__name__}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
logger.error(
|
||||
f"DEBUG - Stack trace: {traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Save the file in the ArtifactService
|
||||
version = artifacts_service.save_artifact(
|
||||
app_name=agent_id,
|
||||
user_id=external_id,
|
||||
session_id=adk_session_id,
|
||||
filename=file_data.filename,
|
||||
artifact=file_part,
|
||||
)
|
||||
logger.info(
|
||||
f"Saved file {file_data.filename} as version {version}"
|
||||
)
|
||||
|
||||
# Add the Part to the list of parts for the message content
|
||||
file_parts.append(file_part)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing file {file_data.filename}: {str(e)}"
|
||||
)
|
||||
|
||||
# Create the content with the text message and the files
|
||||
parts = [Part(text=message)]
|
||||
if file_parts:
|
||||
parts.extend(file_parts)
|
||||
|
||||
content = Content(role="user", parts=parts)
|
||||
logger.info("Starting agent execution")
|
||||
|
||||
final_response_text = "No final response captured."
|
||||
@@ -256,6 +318,7 @@ async def run_agent_stream(
|
||||
memory_service: InMemoryMemoryService,
|
||||
db: Session,
|
||||
session_id: Optional[str] = None,
|
||||
files: Optional[list] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
tracer = get_tracer()
|
||||
span = tracer.start_span(
|
||||
@@ -265,6 +328,7 @@ async def run_agent_stream(
|
||||
"external_id": external_id,
|
||||
"session_id": session_id or f"{external_id}_{agent_id}",
|
||||
"message": message,
|
||||
"has_files": files is not None and len(files) > 0,
|
||||
},
|
||||
)
|
||||
try:
|
||||
@@ -275,6 +339,9 @@ async def run_agent_stream(
|
||||
)
|
||||
logger.info(f"Received message: {message}")
|
||||
|
||||
if files and len(files) > 0:
|
||||
logger.info(f"Received {len(files)} files with message")
|
||||
|
||||
get_root_agent = get_agent(db, agent_id)
|
||||
logger.info(
|
||||
f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
|
||||
@@ -314,7 +381,72 @@ async def run_agent_stream(
|
||||
session_id=adk_session_id,
|
||||
)
|
||||
|
||||
content = Content(role="user", parts=[Part(text=message)])
|
||||
# Process the received files
|
||||
file_parts = []
|
||||
if files and len(files) > 0:
|
||||
for file_data in files:
|
||||
try:
|
||||
# Decode the base64 file
|
||||
file_bytes = base64.b64decode(file_data.data)
|
||||
|
||||
# Detailed debug
|
||||
logger.info(
|
||||
f"DEBUG - Processing file: {file_data.filename}"
|
||||
)
|
||||
logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
|
||||
logger.info(
|
||||
f"DEBUG - MIME type: '{file_data.content_type}'"
|
||||
)
|
||||
logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
|
||||
|
||||
# Create a Part for the file using the default constructor
|
||||
try:
|
||||
file_part = Part(
|
||||
inline_data=Blob(
|
||||
mime_type=file_data.content_type,
|
||||
data=file_bytes,
|
||||
)
|
||||
)
|
||||
logger.info(f"DEBUG - Part created successfully")
|
||||
except Exception as part_error:
|
||||
logger.error(
|
||||
f"DEBUG - Error creating Part: {str(part_error)}"
|
||||
)
|
||||
logger.error(
|
||||
f"DEBUG - Error type: {type(part_error).__name__}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
logger.error(
|
||||
f"DEBUG - Stack trace: {traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
|
||||
# Save the file in the ArtifactService
|
||||
version = artifacts_service.save_artifact(
|
||||
app_name=agent_id,
|
||||
user_id=external_id,
|
||||
session_id=adk_session_id,
|
||||
filename=file_data.filename,
|
||||
artifact=file_part,
|
||||
)
|
||||
logger.info(
|
||||
f"Saved file {file_data.filename} as version {version}"
|
||||
)
|
||||
|
||||
# Add the Part to the list of parts for the message content
|
||||
file_parts.append(file_part)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing file {file_data.filename}: {str(e)}"
|
||||
)
|
||||
|
||||
# Create the content with the text message and the files
|
||||
parts = [Part(text=message)]
|
||||
if file_parts:
|
||||
parts.extend(file_parts)
|
||||
|
||||
content = Content(role="user", parts=parts)
|
||||
logger.info("Starting agent streaming execution")
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user