diff --git a/CHANGELOG.md b/CHANGELOG.md index 389534c3..cdd00c64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add Task Agent for structured single-task execution - Improve context management in agent execution +- Add file support for A2A protocol (Agent-to-Agent) endpoints +- Implement multimodal content processing in A2A messages ## [0.0.9] - 2025-05-13 diff --git a/src/api/a2a_routes.py b/src/api/a2a_routes.py index b3d7d6d4..712a1b2e 100644 --- a/src/api/a2a_routes.py +++ b/src/api/a2a_routes.py @@ -31,6 +31,7 @@ Routes for the A2A (Agent-to-Agent) protocol. This module implements the standard A2A routes according to the specification. +Supports both text messages and file uploads through the message parts mechanism. """ import uuid @@ -92,7 +93,39 @@ async def process_a2a_request( db: Session = Depends(get_db), a2a_service: A2AService = Depends(get_a2a_service), ): - """Processes an A2A request.""" + """ + Processes an A2A request. + + Supports both text messages and file uploads. For file uploads, + include file parts in the message following the A2A protocol format: + + { + "jsonrpc": "2.0", + "id": "request-id", + "method": "tasks/send", + "params": { + "id": "task-id", + "sessionId": "session-id", + "message": { + "role": "user", + "parts": [ + { + "type": "text", + "text": "Analyze this image" + }, + { + "type": "file", + "file": { + "name": "example.jpg", + "mimeType": "image/jpeg", + "bytes": "base64-encoded-content" + } + } + ] + } + } + } + """ # Verify the API key if not verify_api_key(db, x_api_key): raise HTTPException(status_code=401, detail="Invalid API key") @@ -100,10 +133,60 @@ async def process_a2a_request( # Process the request try: request_body = await request.json() + + debug_request_body = {} + if "method" in request_body: + debug_request_body["method"] = request_body["method"] + if "id" in request_body: + debug_request_body["id"] = request_body["id"] + + logger.info(f"A2A request received: {debug_request_body}") + + # Log if request contains file parts for debugging + if isinstance(request_body, dict) and "params" in request_body: + params = request_body.get("params", {}) + message = params.get("message", {}) + parts = message.get("parts", []) + + logger.info(f"A2A message contains {len(parts)} parts") + for i, part in enumerate(parts): + if not isinstance(part, dict): + logger.warning(f"Part {i+1} is not a dictionary: {type(part)}") + continue + + part_type = part.get("type") + logger.info(f"Part {i+1} type: {part_type}") + + if part_type == "file": + file_info = part.get("file", {}) + logger.info( + f"File part found: {file_info.get('name')} ({file_info.get('mimeType')})" + ) + if "bytes" in file_info: + bytes_data = file_info.get("bytes", "") + bytes_size = len(bytes_data) * 0.75 + logger.info(f"File size: ~{bytes_size/1024:.2f} KB") + if bytes_data: + sample = ( + bytes_data[:10] + "..." + if len(bytes_data) > 10 + else bytes_data + ) + logger.info(f"Sample of base64 data: {sample}") + elif part_type == "text": + text_content = part.get("text", "") + preview = ( + text_content[:30] + "..." + if len(text_content) > 30 + else text_content + ) + logger.info(f"Text part found: '{preview}'") + result = await a2a_service.process_request(agent_id, request_body) # If the response is a streaming response, return as EventSourceResponse if hasattr(result, "__aiter__"): + logger.info("Returning streaming response") async def event_generator(): async for item in result: @@ -115,11 +198,15 @@ async def process_a2a_request( return EventSourceResponse(event_generator()) # Otherwise, return as JSONResponse + logger.info("Returning standard JSON response") if hasattr(result, "model_dump"): return JSONResponse(result.model_dump(exclude_none=True)) return JSONResponse(result) except Exception as e: logger.error(f"Error processing A2A request: {e}") + import traceback + + logger.error(f"Full traceback: {traceback.format_exc()}") return JSONResponse( status_code=500, content={ diff --git a/src/api/chat_routes.py b/src/api/chat_routes.py index 8846080b..bf7cd0a2 100644 --- a/src/api/chat_routes.py +++ b/src/api/chat_routes.py @@ -28,6 +28,7 @@ """ import uuid +import base64 from fastapi import ( APIRouter, Depends, @@ -47,7 +48,7 @@ from src.core.jwt_middleware import ( from src.services import ( agent_service, ) -from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse +from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData from src.services.agent_runner import run_agent, run_agent_stream from src.core.exceptions import AgentNotFoundError from src.services.service_providers import ( @@ -59,7 +60,7 @@ from src.services.service_providers import ( from datetime import datetime import logging import json -from typing import Optional, Dict +from typing import Optional, Dict, List, Any logger = logging.getLogger(__name__) @@ -195,6 +196,29 @@ async def websocket_chat( if not message: continue + files = None + if data.get("files") and isinstance(data.get("files"), list): + try: + files = [] + for file_data in data.get("files"): + if ( + isinstance(file_data, dict) + and file_data.get("filename") + and file_data.get("content_type") + and file_data.get("data") + ): + files.append( + FileData( + filename=file_data.get("filename"), + content_type=file_data.get("content_type"), + data=file_data.get("data"), + ) + ) + logger.info(f"Processed {len(files)} files via WebSocket") + except Exception as e: + logger.error(f"Error processing files: {str(e)}") + files = None + async for chunk in run_agent_stream( agent_id=agent_id, external_id=external_id, @@ -203,6 +227,7 @@ async def websocket_chat( artifacts_service=artifacts_service, memory_service=memory_service, db=db, + files=files, ): await websocket.send_json( {"message": json.loads(chunk), "turn_complete": False} @@ -259,6 +284,7 @@ async def chat( artifacts_service, memory_service, db, + files=request.files, ) return { diff --git a/src/api/session_routes.py b/src/api/session_routes.py index 2961293f..500993ae 100644 --- a/src/api/session_routes.py +++ b/src/api/session_routes.py @@ -30,8 +30,9 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from src.config.database import get_db -from typing import List +from typing import List, Optional, Dict, Any import uuid +import base64 from src.core.jwt_middleware import ( get_jwt_token, verify_user_client, @@ -48,7 +49,7 @@ from src.services.session_service import ( get_sessions_by_agent, get_sessions_by_client, ) -from src.services.service_providers import session_service +from src.services.service_providers import session_service, artifacts_service import logging logger = logging.getLogger(__name__) @@ -118,13 +119,18 @@ async def get_session( @router.get( "/{session_id}/messages", - response_model=List[Event], ) async def get_agent_messages( session_id: str, db: Session = Depends(get_db), payload: dict = Depends(get_jwt_token), ): + """ + Gets messages from a session with embedded artifacts. + + This function loads all messages from a session and processes any references + to artifacts, loading them and converting them to base64 for direct use in the frontend. + """ # Get the session session = get_session_by_id(session_service, session_id) if not session: @@ -139,7 +145,160 @@ async def get_agent_messages( if agent: await verify_user_client(payload, db, agent.client_id) - return get_session_events(session_service, session_id) + # Parse session ID para obter app_name e user_id + parts = session_id.split("_") + if len(parts) != 2: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid session ID format" + ) + + user_id, app_name = parts[0], parts[1] + + events = get_session_events(session_service, session_id) + + processed_events = [] + for event in events: + event_dict = event.dict() + + def process_dict(d): + if isinstance(d, dict): + for key, value in list(d.items()): + if isinstance(value, bytes): + try: + d[key] = base64.b64encode(value).decode("utf-8") + logger.debug(f"Converted bytes field to base64: {key}") + except Exception as e: + logger.error(f"Error encoding bytes to base64: {str(e)}") + d[key] = None + elif isinstance(value, dict): + process_dict(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, (dict, list)): + process_dict(item) + elif isinstance(d, list): + for i, item in enumerate(d): + if isinstance(item, bytes): + try: + d[i] = base64.b64encode(item).decode("utf-8") + except Exception as e: + logger.error( + f"Error encoding bytes to base64 in list: {str(e)}" + ) + d[i] = None + elif isinstance(item, (dict, list)): + process_dict(item) + return d + + # Process all event dictionary + event_dict = process_dict(event_dict) + + # Process the content parts specifically + if event_dict.get("content") and event_dict["content"].get("parts"): + for part in event_dict["content"]["parts"]: + # Process inlineData if present + if part and part.get("inlineData") and part["inlineData"].get("data"): + # Check if it's already a string or if it's bytes + if isinstance(part["inlineData"]["data"], bytes): + # Convert bytes to base64 string + part["inlineData"]["data"] = base64.b64encode( + part["inlineData"]["data"] + ).decode("utf-8") + logger.debug( + f"Converted binary data to base64 in message {event_dict.get('id')}" + ) + + # Process fileData if present (reference to an artifact) + if part and part.get("fileData") and part["fileData"].get("fileId"): + try: + # Extract the file name from the fileId + file_id = part["fileData"]["fileId"] + + # Load the artifact from the artifacts service + artifact = artifacts_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=file_id, + ) + + if artifact and hasattr(artifact, "inline_data"): + # Extract the data and MIME type + file_bytes = artifact.inline_data.data + mime_type = artifact.inline_data.mime_type + + # Add inlineData with the artifact data + if not part.get("inlineData"): + part["inlineData"] = {} + + # Ensure we're sending a base64 string, not bytes + if isinstance(file_bytes, bytes): + try: + part["inlineData"]["data"] = base64.b64encode( + file_bytes + ).decode("utf-8") + except Exception as e: + logger.error( + f"Error encoding artifact to base64: {str(e)}" + ) + part["inlineData"]["data"] = None + else: + part["inlineData"]["data"] = str(file_bytes) + + part["inlineData"]["mimeType"] = mime_type + + logger.debug( + f"Loaded artifact {file_id} for message {event_dict.get('id')}" + ) + except Exception as e: + logger.error(f"Error loading artifact: {str(e)}") + # Don't interrupt the flow if an artifact fails + + # Check artifact_delta in actions + if event_dict.get("actions") and event_dict["actions"].get("artifact_delta"): + artifact_deltas = event_dict["actions"]["artifact_delta"] + for filename, version in artifact_deltas.items(): + try: + # Load the artifact + artifact = artifacts_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=version, + ) + + if artifact and hasattr(artifact, "inline_data"): + # If the event doesn't have an artifacts section, create it + if "artifacts" not in event_dict: + event_dict["artifacts"] = {} + + # Add the artifact to the event's artifacts list + file_bytes = artifact.inline_data.data + mime_type = artifact.inline_data.mime_type + + # Ensure the bytes are converted to base64 + event_dict["artifacts"][filename] = { + "data": ( + base64.b64encode(file_bytes).decode("utf-8") + if isinstance(file_bytes, bytes) + else str(file_bytes) + ), + "mimeType": mime_type, + "version": version, + } + + logger.debug( + f"Added artifact {filename} (v{version}) to message {event_dict.get('id')}" + ) + except Exception as e: + logger.error( + f"Error processing artifact_delta {filename}: {str(e)}" + ) + + processed_events.append(event_dict) + + return processed_events @router.delete( diff --git a/src/schemas/chat.py b/src/schemas/chat.py index 696860a8..6188ca7a 100644 --- a/src/schemas/chat.py +++ b/src/schemas/chat.py @@ -27,36 +27,42 @@ └──────────────────────────────────────────────────────────────────────────────┘ """ -from pydantic import BaseModel, Field -from typing import Dict, Any, Optional +from pydantic import BaseModel, Field, validator +from typing import Dict, List, Optional, Any +from datetime import datetime + + +class FileData(BaseModel): + """Model to represent file data sent in a chat request.""" + + filename: str = Field(..., description="File name") + content_type: str = Field(..., description="File content type") + data: str = Field(..., description="File content encoded in base64") class ChatRequest(BaseModel): - """Schema for chat requests""" + """Model to represent a chat request.""" - agent_id: str = Field( - ..., description="ID of the agent that will process the message" + agent_id: str = Field(..., description="Agent ID to process the message") + external_id: str = Field(..., description="External ID for user identification") + message: str = Field(..., description="User message to the agent") + files: Optional[List[FileData]] = Field( + None, description="List of files attached to the message" ) - external_id: str = Field( - ..., description="ID of the external_id that will process the message" - ) - message: str = Field(..., description="User message") class ChatResponse(BaseModel): - """Schema for chat responses""" + """Model to represent a chat response.""" - response: str = Field(..., description="Agent response") - status: str = Field(..., description="Operation status") - error: Optional[str] = Field(None, description="Error message, if there is one") - timestamp: str = Field(..., description="Timestamp of the response") + response: str = Field(..., description="Response generated by the agent") + message_history: List[Dict[str, Any]] = Field( + default_factory=list, description="Message history" + ) + status: str = Field(..., description="Response status (success/error)") + timestamp: str = Field(..., description="Response timestamp") class ErrorResponse(BaseModel): - """Schema for error responses""" + """Model to represent an error response.""" - error: str = Field(..., description="Error message") - status_code: int = Field(..., description="HTTP status code of the error") - details: Optional[Dict[str, Any]] = Field( - None, description="Additional error details" - ) + detail: str = Field(..., description="Error details") diff --git a/src/services/a2a_task_manager.py b/src/services/a2a_task_manager.py index ce86676d..d019cfe7 100644 --- a/src/services/a2a_task_manager.py +++ b/src/services/a2a_task_manager.py @@ -33,8 +33,11 @@ from collections.abc import AsyncIterable from typing import Dict, Optional from uuid import UUID import json +import base64 +import uuid as uuid_pkg from sqlalchemy.orm import Session +from google.genai.types import Part, Blob from src.config.settings import settings from src.services.agent_service import ( @@ -76,6 +79,7 @@ from src.schemas.a2a_types import ( AgentAuthentication, AgentProvider, ) +from src.schemas.chat import FileData logger = logging.getLogger(__name__) @@ -281,12 +285,29 @@ class A2ATaskManager: all_messages.append(agent_message) task_state = self._determine_task_state(result) - artifact = Artifact(parts=agent_message.parts, index=0) + + # Create artifacts for any file content + artifacts = [] + # First, add the main response as an artifact + artifacts.append(Artifact(parts=agent_message.parts, index=0)) + + # Also add any files from the message history + for idx, msg in enumerate(all_messages, 1): + for part in msg.parts: + if hasattr(part, "type") and part.type == "file": + artifacts.append( + Artifact( + parts=[part], + index=idx, + name=part.file.name, + description=f"File from message {idx}", + ) + ) task = await self.update_store( task_params.id, TaskStatus(state=task_state, message=agent_message), - [artifact], + artifacts, ) await self._update_task_history( @@ -400,6 +421,32 @@ class A2ATaskManager: final_message = None + # Check for files in the user message and include them as artifacts + user_files = [] + for part in request.params.message.parts: + if ( + hasattr(part, "type") + and part.type == "file" + and hasattr(part, "file") + ): + user_files.append( + Artifact( + parts=[part], + index=0, + name=part.file.name if part.file.name else "file", + description="File from user", + ) + ) + + # Send artifacts for any user files + for artifact in user_files: + yield SendTaskStreamingResponse( + id=request.id, + result=TaskArtifactUpdateEvent( + id=request.params.id, artifact=artifact + ), + ) + async for chunk in run_agent_stream( agent_id=str(agent.id), external_id=external_id, @@ -418,7 +465,48 @@ class A2ATaskManager: parts = content.get("parts", []) if parts: - update_message = Message(role=role, parts=parts) + # Modify to handle file parts as well + agent_parts = [] + for part in parts: + # Handle different part types + if part.get("type") == "text": + agent_parts.append(part) + full_response += part.get("text", "") + elif part.get("inlineData") and part["inlineData"].get( + "data" + ): + # Convert inline data to file part + mime_type = part["inlineData"].get( + "mimeType", "application/octet-stream" + ) + file_name = f"file_{uuid_pkg.uuid4().hex}{self._get_extension_from_mime(mime_type)}" + file_part = { + "type": "file", + "file": { + "name": file_name, + "mimeType": mime_type, + "bytes": part["inlineData"]["data"], + }, + } + agent_parts.append(file_part) + + # Also send as artifact + yield SendTaskStreamingResponse( + id=request.id, + result=TaskArtifactUpdateEvent( + id=request.params.id, + artifact=Artifact( + parts=[file_part], + index=0, + name=file_name, + description=f"Generated {mime_type} file", + ), + ), + ) + + if agent_parts: + update_message = Message(role=role, parts=agent_parts) + final_message = update_message yield SendTaskStreamingResponse( id=request.id, @@ -431,11 +519,6 @@ class A2ATaskManager: final=False, ), ) - - for part in parts: - if part.get("type") == "text": - full_response += part.get("text", "") - final_message = update_message except Exception as e: logger.error(f"Error processing chunk: {e}, chunk: {chunk}") continue @@ -485,6 +568,29 @@ class A2ATaskManager: error=InternalError(message=f"Error streaming task process: {str(e)}"), ) + def _get_extension_from_mime(self, mime_type: str) -> str: + """Get a file extension from MIME type.""" + if not mime_type: + return "" + + mime_map = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "application/pdf": ".pdf", + "text/plain": ".txt", + "text/html": ".html", + "text/csv": ".csv", + "application/json": ".json", + "application/xml": ".xml", + "application/msword": ".doc", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.ms-excel": ".xls", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + } + + return mime_map.get(mime_type, "") + async def update_store( self, task_id: str, @@ -514,19 +620,193 @@ class A2ATaskManager: return task def _extract_user_query(self, task_params: TaskSendParams) -> str: - """Extracts the user query from the task parameters.""" + """Extracts the user query from the task parameters and processes any files.""" if not task_params.message or not task_params.message.parts: raise ValueError("Message or parts are missing in task parameters") - part = task_params.message.parts[0] - if part.type != "text": - raise ValueError("Only text parts are supported") + # Process file parts first + text_parts = [] + has_files = False + file_parts = [] - return part.text + logger.info( + f"Extracting query from message with {len(task_params.message.parts)} parts" + ) + + # Extract text parts and file parts separately + for idx, part in enumerate(task_params.message.parts): + logger.info( + f"Processing part {idx+1}, type: {getattr(part, 'type', 'unknown')}" + ) + if hasattr(part, "type"): + if part.type == "text": + logger.info(f"Found text part: '{part.text[:50]}...' (truncated)") + text_parts.append(part.text) + elif part.type == "file": + logger.info( + f"Found file part: {getattr(getattr(part, 'file', None), 'name', 'unnamed')}" + ) + has_files = True + try: + processed_file = self._process_file_part( + part, task_params.sessionId + ) + if processed_file: + file_parts.append(processed_file) + except Exception as e: + logger.error(f"Error processing file part: {e}") + # Continue with other parts even if a file fails + else: + logger.warning(f"Unknown part type: {part.type}") + else: + logger.warning(f"Part has no type attribute: {part}") + + # Store the file parts in self for later use + self._last_processed_files = file_parts if file_parts else None + + # If we have at least one text part, use that as the query + if text_parts: + final_query = " ".join(text_parts) + logger.info( + f"Final query from text parts: '{final_query[:50]}...' (truncated)" + ) + return final_query + # If we only have file parts, create a generic query asking for analysis + elif has_files: + logger.info("No text parts, using generic query for file analysis") + return "Analyze the attached files" + else: + logger.error("No supported content parts found in the message") + raise ValueError("No supported content parts found in the message") + + def _process_file_part(self, part, session_id: str): + """Processes a file part and saves it to the artifact service. + + Returns: + dict: Processed file information to pass to agent_runner + """ + if not hasattr(part, "file") or not part.file: + logger.warning("File part missing file data") + return None + + file_data = part.file + + if not file_data.name: + file_data.name = f"file_{uuid_pkg.uuid4().hex}" + + logger.info(f"Processing file {file_data.name} for session {session_id}") + + if file_data.bytes: + # Process file data provided as base64 string + try: + # Convert base64 to bytes + logger.info(f"Decoding base64 content for file {file_data.name}") + file_bytes = base64.b64decode(file_data.bytes) + + # Determine MIME type based on binary content + mime_type = ( + file_data.mimeType if hasattr(file_data, "mimeType") else None + ) + + if not mime_type or mime_type == "application/octet-stream": + # Detection by byte signature + if file_bytes.startswith(b"\xff\xd8\xff"): # JPEG signature + mime_type = "image/jpeg" + elif file_bytes.startswith(b"\x89PNG\r\n\x1a\n"): # PNG signature + mime_type = "image/png" + elif file_bytes.startswith(b"GIF87a") or file_bytes.startswith( + b"GIF89a" + ): # GIF + mime_type = "image/gif" + elif file_bytes.startswith(b"%PDF"): # PDF + mime_type = "application/pdf" + else: + # Fallback to avoid generic type in images + if file_data.name.lower().endswith((".jpg", ".jpeg")): + mime_type = "image/jpeg" + elif file_data.name.lower().endswith(".png"): + mime_type = "image/png" + elif file_data.name.lower().endswith(".gif"): + mime_type = "image/gif" + elif file_data.name.lower().endswith(".pdf"): + mime_type = "application/pdf" + else: + mime_type = "application/octet-stream" + + logger.info( + f"Decoded file size: {len(file_bytes)} bytes, MIME type: {mime_type}" + ) + + # Split session_id to get app_name and user_id + parts = session_id.split("_") + if len(parts) != 2: + user_id = session_id + app_name = "a2a" + else: + user_id, app_name = parts + + # Create artifact Part + logger.info(f"Creating artifact Part for file {file_data.name}") + artifact = Part(inline_data=Blob(mime_type=mime_type, data=file_bytes)) + + # Save to artifact service + logger.info( + f"Saving artifact {file_data.name} to {app_name}/{user_id}/{session_id}" + ) + version = artifacts_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=file_data.name, + artifact=artifact, + ) + + logger.info( + f"Successfully saved file {file_data.name} (version {version}) for session {session_id}" + ) + + # Import the FileData model from the chat schema + from src.schemas.chat import FileData + + # Create a FileData object instead of a dictionary + # This is compatible with what agent_runner.py expects + return FileData( + filename=file_data.name, + content_type=mime_type, + data=file_data.bytes, # Keep the original base64 format + ) + + except Exception as e: + logger.error(f"Error processing file data: {str(e)}") + # Log more details about the error to help with debugging + import traceback + + logger.error(f"Full traceback: {traceback.format_exc()}") + raise + + elif file_data.uri: + # Handling URIs would require additional implementation + # For now, log that we received a URI but can't process it + logger.warning(f"File URI references not yet implemented: {file_data.uri}") + # Future enhancement: fetch the file from the URI and save it + return None + + return None async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict: """Executes the agent to process the user query.""" try: + files = getattr(self, "_last_processed_files", None) + + if files: + logger.info(f"Passing {len(files)} files to run_agent") + for file_info in files: + logger.info( + f"File being sent: {file_info.filename} ({file_info.content_type})" + ) + else: + logger.info("No files to pass to run_agent") + # We call the same function used in the chat API return await run_agent( agent_id=str(agent.id), @@ -536,6 +816,7 @@ class A2ATaskManager: artifacts_service=artifacts_service, memory_service=memory_service, db=self.db, + files=files, ) except Exception as e: logger.error(f"Error running agent: {e}") diff --git a/src/services/agent_runner.py b/src/services/agent_runner.py index 1d330754..4e205875 100644 --- a/src/services/agent_runner.py +++ b/src/services/agent_runner.py @@ -28,7 +28,7 @@ """ from google.adk.runners import Runner -from google.genai.types import Content, Part +from google.genai.types import Content, Part, Blob from google.adk.sessions import DatabaseSessionService from google.adk.memory import InMemoryMemoryService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService @@ -42,6 +42,7 @@ import asyncio import json from src.utils.otel import get_tracer from opentelemetry import trace +import base64 logger = setup_logger(__name__) @@ -56,6 +57,7 @@ async def run_agent( db: Session, session_id: Optional[str] = None, timeout: float = 60.0, + files: Optional[list] = None, ): tracer = get_tracer() with tracer.start_as_current_span( @@ -65,6 +67,7 @@ async def run_agent( "external_id": external_id, "session_id": session_id or f"{external_id}_{agent_id}", "message": message, + "has_files": files is not None and len(files) > 0, }, ): exit_stack = None @@ -74,6 +77,9 @@ async def run_agent( ) logger.info(f"Received message: {message}") + if files and len(files) > 0: + logger.info(f"Received {len(files)} files with message") + get_root_agent = get_agent(db, agent_id) logger.info( f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})" @@ -113,7 +119,63 @@ async def run_agent( session_id=adk_session_id, ) - content = Content(role="user", parts=[Part(text=message)]) + file_parts = [] + if files and len(files) > 0: + for file_data in files: + try: + file_bytes = base64.b64decode(file_data.data) + + logger.info(f"DEBUG - Processing file: {file_data.filename}") + logger.info(f"DEBUG - File size: {len(file_bytes)} bytes") + logger.info(f"DEBUG - MIME type: '{file_data.content_type}'") + logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}") + + try: + file_part = Part( + inline_data=Blob( + mime_type=file_data.content_type, data=file_bytes + ) + ) + logger.info(f"DEBUG - Part created successfully") + except Exception as part_error: + logger.error( + f"DEBUG - Error creating Part: {str(part_error)}" + ) + logger.error( + f"DEBUG - Error type: {type(part_error).__name__}" + ) + import traceback + + logger.error( + f"DEBUG - Stack trace: {traceback.format_exc()}" + ) + raise + + # Save the file in the ArtifactService + version = artifacts_service.save_artifact( + app_name=agent_id, + user_id=external_id, + session_id=adk_session_id, + filename=file_data.filename, + artifact=file_part, + ) + logger.info( + f"Saved file {file_data.filename} as version {version}" + ) + + # Add the Part to the list of parts for the message content + file_parts.append(file_part) + except Exception as e: + logger.error( + f"Error processing file {file_data.filename}: {str(e)}" + ) + + # Create the content with the text message and the files + parts = [Part(text=message)] + if file_parts: + parts.extend(file_parts) + + content = Content(role="user", parts=parts) logger.info("Starting agent execution") final_response_text = "No final response captured." @@ -256,6 +318,7 @@ async def run_agent_stream( memory_service: InMemoryMemoryService, db: Session, session_id: Optional[str] = None, + files: Optional[list] = None, ) -> AsyncGenerator[str, None]: tracer = get_tracer() span = tracer.start_span( @@ -265,6 +328,7 @@ async def run_agent_stream( "external_id": external_id, "session_id": session_id or f"{external_id}_{agent_id}", "message": message, + "has_files": files is not None and len(files) > 0, }, ) try: @@ -275,6 +339,9 @@ async def run_agent_stream( ) logger.info(f"Received message: {message}") + if files and len(files) > 0: + logger.info(f"Received {len(files)} files with message") + get_root_agent = get_agent(db, agent_id) logger.info( f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})" @@ -314,7 +381,72 @@ async def run_agent_stream( session_id=adk_session_id, ) - content = Content(role="user", parts=[Part(text=message)]) + # Process the received files + file_parts = [] + if files and len(files) > 0: + for file_data in files: + try: + # Decode the base64 file + file_bytes = base64.b64decode(file_data.data) + + # Detailed debug + logger.info( + f"DEBUG - Processing file: {file_data.filename}" + ) + logger.info(f"DEBUG - File size: {len(file_bytes)} bytes") + logger.info( + f"DEBUG - MIME type: '{file_data.content_type}'" + ) + logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}") + + # Create a Part for the file using the default constructor + try: + file_part = Part( + inline_data=Blob( + mime_type=file_data.content_type, + data=file_bytes, + ) + ) + logger.info(f"DEBUG - Part created successfully") + except Exception as part_error: + logger.error( + f"DEBUG - Error creating Part: {str(part_error)}" + ) + logger.error( + f"DEBUG - Error type: {type(part_error).__name__}" + ) + import traceback + + logger.error( + f"DEBUG - Stack trace: {traceback.format_exc()}" + ) + raise + + # Save the file in the ArtifactService + version = artifacts_service.save_artifact( + app_name=agent_id, + user_id=external_id, + session_id=adk_session_id, + filename=file_data.filename, + artifact=file_part, + ) + logger.info( + f"Saved file {file_data.filename} as version {version}" + ) + + # Add the Part to the list of parts for the message content + file_parts.append(file_part) + except Exception as e: + logger.error( + f"Error processing file {file_data.filename}: {str(e)}" + ) + + # Create the content with the text message and the files + parts = [Part(text=message)] + if file_parts: + parts.extend(file_parts) + + content = Content(role="user", parts=parts) logger.info("Starting agent streaming execution") try: diff --git a/src/utils/a2a_utils.py b/src/utils/a2a_utils.py index e58b968c..62e9bbd5 100644 --- a/src/utils/a2a_utils.py +++ b/src/utils/a2a_utils.py @@ -27,10 +27,16 @@ └──────────────────────────────────────────────────────────────────────────────┘ """ +import base64 +import uuid +from typing import Dict, List, Any, Optional +from google.genai.types import Part, Blob + from src.schemas.a2a_types import ( ContentTypeNotSupportedError, JSONRPCResponse, UnsupportedOperationError, + Message, ) @@ -55,3 +61,130 @@ def new_incompatible_types_error(request_id): def new_not_implemented_error(request_id): return JSONRPCResponse(id=request_id, error=UnsupportedOperationError()) + + +def extract_files_from_message(message: Message) -> List[Dict[str, Any]]: + """ + Extract file parts from an A2A message. + + Args: + message: An A2A Message object + + Returns: + List of file parts extracted from the message + """ + if not message or not message.parts: + return [] + + files = [] + for part in message.parts: + if hasattr(part, "type") and part.type == "file" and hasattr(part, "file"): + files.append(part) + + return files + + +def a2a_part_to_adk_part(a2a_part: Dict[str, Any]) -> Optional[Part]: + """ + Convert an A2A protocol part to an ADK Part object. + + Args: + a2a_part: An A2A part dictionary + + Returns: + Converted ADK Part object or None if conversion not possible + """ + part_type = a2a_part.get("type") + if part_type == "file" and "file" in a2a_part: + file_data = a2a_part["file"] + if "bytes" in file_data: + try: + # Convert base64 to bytes + file_bytes = base64.b64decode(file_data["bytes"]) + mime_type = file_data.get("mimeType", "application/octet-stream") + + # Create ADK Part + return Part(inline_data=Blob(mime_type=mime_type, data=file_bytes)) + except Exception: + return None + elif part_type == "text" and "text" in a2a_part: + # For text parts, we could create a text blob if needed + return None + + return None + + +def adk_part_to_a2a_part( + adk_part: Part, filename: Optional[str] = None +) -> Optional[Dict[str, Any]]: + """ + Convert an ADK Part object to an A2A protocol part. + + Args: + adk_part: An ADK Part object + filename: Optional filename to use + + Returns: + Converted A2A Part dictionary or None if conversion not possible + """ + if hasattr(adk_part, "inline_data") and adk_part.inline_data: + if adk_part.inline_data.data and adk_part.inline_data.mime_type: + # Convert binary data to base64 + file_bytes = adk_part.inline_data.data + mime_type = adk_part.inline_data.mime_type + + # Generate filename if not provided + if not filename: + ext = get_extension_from_mime(mime_type) + filename = f"file_{uuid.uuid4().hex}{ext}" + + # Convert to A2A FilePart dict + return { + "type": "file", + "file": { + "name": filename, + "mimeType": mime_type, + "bytes": ( + base64.b64encode(file_bytes).decode("utf-8") + if isinstance(file_bytes, bytes) + else str(file_bytes) + ), + }, + } + elif hasattr(adk_part, "text") and adk_part.text: + # Convert text part + return {"type": "text", "text": adk_part.text} + + return None + + +def get_extension_from_mime(mime_type: str) -> str: + """ + Get a file extension from MIME type. + + Args: + mime_type: MIME type string + + Returns: + Appropriate file extension with leading dot + """ + if not mime_type: + return "" + + mime_map = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "application/pdf": ".pdf", + "text/plain": ".txt", + "text/html": ".html", + "text/csv": ".csv", + "application/json": ".json", + "application/xml": ".xml", + "application/msword": ".doc", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.ms-excel": ".xls", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + } + + return mime_map.get(mime_type, "")