feat(a2a): add file support and multimodal content processing for A2A protocol

parent 958eeec4a6
commit 6bf0ea52e0
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Add Task Agent for structured single-task execution
 - Improve context management in agent execution
+- Add file support for A2A protocol (Agent-to-Agent) endpoints
+- Implement multimodal content processing in A2A messages
 
 ## [0.0.9] - 2025-05-13
 
@@ -31,6 +31,7 @@
 Routes for the A2A (Agent-to-Agent) protocol.
 
 This module implements the standard A2A routes according to the specification.
+Supports both text messages and file uploads through the message parts mechanism.
 """
 
 import uuid
@@ -92,7 +93,39 @@ async def process_a2a_request(
     db: Session = Depends(get_db),
     a2a_service: A2AService = Depends(get_a2a_service),
 ):
-    """Processes an A2A request."""
+    """
+    Processes an A2A request.
+
+    Supports both text messages and file uploads. For file uploads,
+    include file parts in the message following the A2A protocol format:
+
+    {
+        "jsonrpc": "2.0",
+        "id": "request-id",
+        "method": "tasks/send",
+        "params": {
+            "id": "task-id",
+            "sessionId": "session-id",
+            "message": {
+                "role": "user",
+                "parts": [
+                    {
+                        "type": "text",
+                        "text": "Analyze this image"
+                    },
+                    {
+                        "type": "file",
+                        "file": {
+                            "name": "example.jpg",
+                            "mimeType": "image/jpeg",
+                            "bytes": "base64-encoded-content"
+                        }
+                    }
+                ]
+            }
+        }
+    }
+    """
     # Verify the API key
     if not verify_api_key(db, x_api_key):
         raise HTTPException(status_code=401, detail="Invalid API key")
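For reference, a minimal client-side sketch of the documented format. The route path and the `x-api-key` header name are assumptions inferred from the handler's parameters; the route decorator and mount point are not part of this diff:

import base64
import httpx

payload = {
    "jsonrpc": "2.0",
    "id": "request-1",
    "method": "tasks/send",
    "params": {
        "id": "task-1",
        "sessionId": "user-1_my-app",
        "message": {
            "role": "user",
            "parts": [
                {"type": "text", "text": "Analyze this image"},
                {
                    "type": "file",
                    "file": {
                        "name": "example.jpg",
                        "mimeType": "image/jpeg",
                        # base64-encode the raw bytes, as the docstring above describes
                        "bytes": base64.b64encode(open("example.jpg", "rb").read()).decode("utf-8"),
                    },
                },
            ],
        },
    },
}

# Hypothetical URL; the actual mount point is defined outside this diff.
response = httpx.post(
    "http://localhost:8000/a2a/agent-123",
    json=payload,
    headers={"x-api-key": "your-api-key"},  # header name inferred from the x_api_key parameter
)
print(response.json())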
@@ -100,10 +133,60 @@ async def process_a2a_request(
     # Process the request
     try:
         request_body = await request.json()
+
+        debug_request_body = {}
+        if "method" in request_body:
+            debug_request_body["method"] = request_body["method"]
+        if "id" in request_body:
+            debug_request_body["id"] = request_body["id"]
+
+        logger.info(f"A2A request received: {debug_request_body}")
+
+        # Log if request contains file parts for debugging
+        if isinstance(request_body, dict) and "params" in request_body:
+            params = request_body.get("params", {})
+            message = params.get("message", {})
+            parts = message.get("parts", [])
+
+            logger.info(f"A2A message contains {len(parts)} parts")
+            for i, part in enumerate(parts):
+                if not isinstance(part, dict):
+                    logger.warning(f"Part {i+1} is not a dictionary: {type(part)}")
+                    continue
+
+                part_type = part.get("type")
+                logger.info(f"Part {i+1} type: {part_type}")
+
+                if part_type == "file":
+                    file_info = part.get("file", {})
+                    logger.info(
+                        f"File part found: {file_info.get('name')} ({file_info.get('mimeType')})"
+                    )
+                    if "bytes" in file_info:
+                        bytes_data = file_info.get("bytes", "")
+                        bytes_size = len(bytes_data) * 0.75
+                        logger.info(f"File size: ~{bytes_size/1024:.2f} KB")
+                        if bytes_data:
+                            sample = (
+                                bytes_data[:10] + "..."
+                                if len(bytes_data) > 10
+                                else bytes_data
+                            )
+                            logger.info(f"Sample of base64 data: {sample}")
+                elif part_type == "text":
+                    text_content = part.get("text", "")
+                    preview = (
+                        text_content[:30] + "..."
+                        if len(text_content) > 30
+                        else text_content
+                    )
+                    logger.info(f"Text part found: '{preview}'")
+
         result = await a2a_service.process_request(agent_id, request_body)
 
         # If the response is a streaming response, return as EventSourceResponse
         if hasattr(result, "__aiter__"):
+            logger.info("Returning streaming response")
+
             async def event_generator():
                 async for item in result:

@@ -115,11 +198,15 @@ async def process_a2a_request(
             return EventSourceResponse(event_generator())
 
         # Otherwise, return as JSONResponse
+        logger.info("Returning standard JSON response")
         if hasattr(result, "model_dump"):
             return JSONResponse(result.model_dump(exclude_none=True))
         return JSONResponse(result)
     except Exception as e:
         logger.error(f"Error processing A2A request: {e}")
+        import traceback
+
+        logger.error(f"Full traceback: {traceback.format_exc()}")
        return JSONResponse(
            status_code=500,
            content={
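The `len(bytes_data) * 0.75` estimate above works because base64 encodes every 3 payload bytes as 4 ASCII characters, so the decoded size is roughly three quarters of the string length (padding makes it a slight overestimate). A quick sanity check:

import base64

payload = b"x" * 3000
encoded = base64.b64encode(payload).decode("utf-8")
assert len(encoded) == 4000          # 4 characters per 3 bytes
assert len(encoded) * 0.75 == 3000   # the handler's size estimate is exact here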
@@ -28,6 +28,7 @@
 """
 
 import uuid
+import base64
 from fastapi import (
     APIRouter,
     Depends,

@@ -47,7 +48,7 @@ from src.core.jwt_middleware import (
 from src.services import (
     agent_service,
 )
-from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse
+from src.schemas.chat import ChatRequest, ChatResponse, ErrorResponse, FileData
 from src.services.agent_runner import run_agent, run_agent_stream
 from src.core.exceptions import AgentNotFoundError
 from src.services.service_providers import (

@@ -59,7 +60,7 @@ from src.services.service_providers import (
 from datetime import datetime
 import logging
 import json
-from typing import Optional, Dict
+from typing import Optional, Dict, List, Any
 
 logger = logging.getLogger(__name__)
 
@@ -195,6 +196,29 @@ async def websocket_chat(
             if not message:
                 continue
 
+            files = None
+            if data.get("files") and isinstance(data.get("files"), list):
+                try:
+                    files = []
+                    for file_data in data.get("files"):
+                        if (
+                            isinstance(file_data, dict)
+                            and file_data.get("filename")
+                            and file_data.get("content_type")
+                            and file_data.get("data")
+                        ):
+                            files.append(
+                                FileData(
+                                    filename=file_data.get("filename"),
+                                    content_type=file_data.get("content_type"),
+                                    data=file_data.get("data"),
+                                )
+                            )
+                    logger.info(f"Processed {len(files)} files via WebSocket")
+                except Exception as e:
+                    logger.error(f"Error processing files: {str(e)}")
+                    files = None
+
             async for chunk in run_agent_stream(
                 agent_id=agent_id,
                 external_id=external_id,

@@ -203,6 +227,7 @@ async def websocket_chat(
                 artifacts_service=artifacts_service,
                 memory_service=memory_service,
                 db=db,
+                files=files,
             ):
                 await websocket.send_json(
                     {"message": json.loads(chunk), "turn_complete": False}
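On the wire, each entry in the WebSocket payload's `files` list must carry `filename`, `content_type`, and base64 `data`, or it is skipped by the validation above. A client sketch; the `message` key and endpoint path are assumptions, since the full handler and route are not shown in this diff:

import asyncio
import base64
import json

import websockets  # assumed client library; any WebSocket client works

async def send_message_with_file(url: str) -> None:
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({
            "message": "What does this file contain?",  # key assumed from the surrounding handler
            "files": [
                {
                    "filename": "notes.txt",
                    "content_type": "text/plain",
                    "data": base64.b64encode(b"some notes").decode("utf-8"),
                }
            ],
        }))
        reply = json.loads(await ws.recv())
        print(reply)  # chunks arrive as {"message": ..., "turn_complete": False}

# asyncio.run(send_message_with_file("ws://localhost:8000/<chat-ws-path>"))  # path not shown in this diff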
@@ -259,6 +284,7 @@ async def chat(
             artifacts_service,
             memory_service,
             db,
+            files=request.files,
         )
 
         return {
@@ -30,8 +30,9 @@
 from fastapi import APIRouter, Depends, HTTPException, status
 from sqlalchemy.orm import Session
 from src.config.database import get_db
-from typing import List
+from typing import List, Optional, Dict, Any
 import uuid
+import base64
 from src.core.jwt_middleware import (
     get_jwt_token,
     verify_user_client,

@@ -48,7 +49,7 @@ from src.services.session_service import (
     get_sessions_by_agent,
     get_sessions_by_client,
 )
-from src.services.service_providers import session_service
+from src.services.service_providers import session_service, artifacts_service
 import logging
 
 logger = logging.getLogger(__name__)
@@ -118,13 +119,18 @@ async def get_session(
 
 @router.get(
     "/{session_id}/messages",
-    response_model=List[Event],
 )
 async def get_agent_messages(
     session_id: str,
     db: Session = Depends(get_db),
     payload: dict = Depends(get_jwt_token),
 ):
+    """
+    Gets messages from a session with embedded artifacts.
+
+    This function loads all messages from a session and processes any references
+    to artifacts, loading them and converting them to base64 for direct use in the frontend.
+    """
     # Get the session
     session = get_session_by_id(session_service, session_id)
     if not session:
@@ -139,7 +145,160 @@ async def get_agent_messages(
     if agent:
         await verify_user_client(payload, db, agent.client_id)
 
-    return get_session_events(session_service, session_id)
+    # Parse the session ID to get app_name and user_id
+    parts = session_id.split("_")
+    if len(parts) != 2:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid session ID format"
+        )
+
+    user_id, app_name = parts[0], parts[1]
+
+    events = get_session_events(session_service, session_id)
+
+    processed_events = []
+    for event in events:
+        event_dict = event.dict()
+
+        def process_dict(d):
+            if isinstance(d, dict):
+                for key, value in list(d.items()):
+                    if isinstance(value, bytes):
+                        try:
+                            d[key] = base64.b64encode(value).decode("utf-8")
+                            logger.debug(f"Converted bytes field to base64: {key}")
+                        except Exception as e:
+                            logger.error(f"Error encoding bytes to base64: {str(e)}")
+                            d[key] = None
+                    elif isinstance(value, dict):
+                        process_dict(value)
+                    elif isinstance(value, list):
+                        for item in value:
+                            if isinstance(item, (dict, list)):
+                                process_dict(item)
+            elif isinstance(d, list):
+                for i, item in enumerate(d):
+                    if isinstance(item, bytes):
+                        try:
+                            d[i] = base64.b64encode(item).decode("utf-8")
+                        except Exception as e:
+                            logger.error(
+                                f"Error encoding bytes to base64 in list: {str(e)}"
+                            )
+                            d[i] = None
+                    elif isinstance(item, (dict, list)):
+                        process_dict(item)
+            return d
+
+        # Process the whole event dictionary
+        event_dict = process_dict(event_dict)
+
+        # Process the content parts specifically
+        if event_dict.get("content") and event_dict["content"].get("parts"):
+            for part in event_dict["content"]["parts"]:
+                # Process inlineData if present
+                if part and part.get("inlineData") and part["inlineData"].get("data"):
+                    # Check if it's already a string or if it's bytes
+                    if isinstance(part["inlineData"]["data"], bytes):
+                        # Convert bytes to a base64 string
+                        part["inlineData"]["data"] = base64.b64encode(
+                            part["inlineData"]["data"]
+                        ).decode("utf-8")
+                        logger.debug(
+                            f"Converted binary data to base64 in message {event_dict.get('id')}"
+                        )
+
+                # Process fileData if present (a reference to an artifact)
+                if part and part.get("fileData") and part["fileData"].get("fileId"):
+                    try:
+                        # Extract the file name from the fileId
+                        file_id = part["fileData"]["fileId"]
+
+                        # Load the artifact from the artifacts service
+                        artifact = artifacts_service.load_artifact(
+                            app_name=app_name,
+                            user_id=user_id,
+                            session_id=session_id,
+                            filename=file_id,
+                        )
+
+                        if artifact and hasattr(artifact, "inline_data"):
+                            # Extract the data and MIME type
+                            file_bytes = artifact.inline_data.data
+                            mime_type = artifact.inline_data.mime_type
+
+                            # Add inlineData with the artifact data
+                            if not part.get("inlineData"):
+                                part["inlineData"] = {}
+
+                            # Ensure we're sending a base64 string, not bytes
+                            if isinstance(file_bytes, bytes):
+                                try:
+                                    part["inlineData"]["data"] = base64.b64encode(
+                                        file_bytes
+                                    ).decode("utf-8")
+                                except Exception as e:
+                                    logger.error(
+                                        f"Error encoding artifact to base64: {str(e)}"
+                                    )
+                                    part["inlineData"]["data"] = None
+                            else:
+                                part["inlineData"]["data"] = str(file_bytes)
+
+                            part["inlineData"]["mimeType"] = mime_type
+
+                            logger.debug(
+                                f"Loaded artifact {file_id} for message {event_dict.get('id')}"
+                            )
+                    except Exception as e:
+                        logger.error(f"Error loading artifact: {str(e)}")
+                        # Don't interrupt the flow if an artifact fails
+
+        # Check artifact_delta in actions
+        if event_dict.get("actions") and event_dict["actions"].get("artifact_delta"):
+            artifact_deltas = event_dict["actions"]["artifact_delta"]
+            for filename, version in artifact_deltas.items():
+                try:
+                    # Load the artifact
+                    artifact = artifacts_service.load_artifact(
+                        app_name=app_name,
+                        user_id=user_id,
+                        session_id=session_id,
+                        filename=filename,
+                        version=version,
+                    )
+
+                    if artifact and hasattr(artifact, "inline_data"):
+                        # If the event doesn't have an artifacts section, create it
+                        if "artifacts" not in event_dict:
+                            event_dict["artifacts"] = {}
+
+                        # Add the artifact to the event's artifacts list
+                        file_bytes = artifact.inline_data.data
+                        mime_type = artifact.inline_data.mime_type
+
+                        # Ensure the bytes are converted to base64
+                        event_dict["artifacts"][filename] = {
+                            "data": (
+                                base64.b64encode(file_bytes).decode("utf-8")
+                                if isinstance(file_bytes, bytes)
+                                else str(file_bytes)
+                            ),
+                            "mimeType": mime_type,
+                            "version": version,
+                        }
+
+                        logger.debug(
+                            f"Added artifact {filename} (v{version}) to message {event_dict.get('id')}"
+                        )
+                except Exception as e:
+                    logger.error(
+                        f"Error processing artifact_delta {filename}: {str(e)}"
+                    )
+
+        processed_events.append(event_dict)
+
+    return processed_events
 
 
 @router.delete(
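This route relies on the `{user_id}_{app_name}` session-ID convention (elsewhere in this commit, sessions are created as `f"{external_id}_{agent_id}"`), and it now returns plain dictionaries rather than `Event` models, which is why `response_model=List[Event]` was dropped above. Roughly, a processed event takes this shape (illustrative only; the values are hypothetical):

session_id = "user-1_agent-123"
user_id, app_name = session_id.split("_")  # -> "user-1", "agent-123"

processed_event = {
    "id": "evt-1",
    "content": {
        "parts": [
            # binary payloads arrive base64-encoded, ready for the frontend
            {"inlineData": {"data": "<base64>", "mimeType": "image/png"}},
        ]
    },
    # artifacts referenced via actions.artifact_delta are embedded here
    "artifacts": {
        "chart.png": {"data": "<base64>", "mimeType": "image/png", "version": 1},
    },
}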
@@ -27,36 +27,42 @@
 └──────────────────────────────────────────────────────────────────────────────┘
 """
 
-from pydantic import BaseModel, Field
-from typing import Dict, Any, Optional
+from pydantic import BaseModel, Field, validator
+from typing import Dict, List, Optional, Any
+from datetime import datetime
+
+
+class FileData(BaseModel):
+    """Model to represent file data sent in a chat request."""
+
+    filename: str = Field(..., description="File name")
+    content_type: str = Field(..., description="File content type")
+    data: str = Field(..., description="File content encoded in base64")
 
 
 class ChatRequest(BaseModel):
-    """Schema for chat requests"""
+    """Model to represent a chat request."""
 
-    agent_id: str = Field(
-        ..., description="ID of the agent that will process the message"
-    )
-    external_id: str = Field(
-        ..., description="ID of the external_id that will process the message"
-    )
-    message: str = Field(..., description="User message")
+    agent_id: str = Field(..., description="Agent ID to process the message")
+    external_id: str = Field(..., description="External ID for user identification")
+    message: str = Field(..., description="User message to the agent")
+    files: Optional[List[FileData]] = Field(
+        None, description="List of files attached to the message"
+    )
 
 
 class ChatResponse(BaseModel):
-    """Schema for chat responses"""
+    """Model to represent a chat response."""
 
-    response: str = Field(..., description="Agent response")
-    status: str = Field(..., description="Operation status")
-    error: Optional[str] = Field(None, description="Error message, if there is one")
-    timestamp: str = Field(..., description="Timestamp of the response")
+    response: str = Field(..., description="Response generated by the agent")
+    message_history: List[Dict[str, Any]] = Field(
+        default_factory=list, description="Message history"
+    )
+    status: str = Field(..., description="Response status (success/error)")
+    timestamp: str = Field(..., description="Response timestamp")
 
 
 class ErrorResponse(BaseModel):
-    """Schema for error responses"""
+    """Model to represent an error response."""
 
-    error: str = Field(..., description="Error message")
-    status_code: int = Field(..., description="HTTP status code of the error")
-    details: Optional[Dict[str, Any]] = Field(
-        None, description="Additional error details"
-    )
+    detail: str = Field(..., description="Error details")
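With the new `FileData` model, file payloads stay base64 strings end to end. A request body could be built like this (a sketch; it assumes Pydantic v1-style models, consistent with the `validator` import above, and the file name and IDs are hypothetical):

import base64

from src.schemas.chat import ChatRequest, FileData

with open("example.jpg", "rb") as f:  # hypothetical file
    encoded = base64.b64encode(f.read()).decode("utf-8")

request = ChatRequest(
    agent_id="agent-123",   # hypothetical IDs
    external_id="user-1",
    message="Describe this image",
    files=[FileData(filename="example.jpg", content_type="image/jpeg", data=encoded)],
)

print(request.dict())  # serializable body for the chat endpoint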
@@ -33,8 +33,11 @@ from collections.abc import AsyncIterable
 from typing import Dict, Optional
 from uuid import UUID
 import json
+import base64
+import uuid as uuid_pkg
 
 from sqlalchemy.orm import Session
+from google.genai.types import Part, Blob
 
 from src.config.settings import settings
 from src.services.agent_service import (

@@ -76,6 +79,7 @@ from src.schemas.a2a_types import (
     AgentAuthentication,
     AgentProvider,
 )
+from src.schemas.chat import FileData
 
 logger = logging.getLogger(__name__)
 
@@ -281,12 +285,29 @@ class A2ATaskManager:
         all_messages.append(agent_message)
 
         task_state = self._determine_task_state(result)
-        artifact = Artifact(parts=agent_message.parts, index=0)
+
+        # Create artifacts for any file content
+        artifacts = []
+        # First, add the main response as an artifact
+        artifacts.append(Artifact(parts=agent_message.parts, index=0))
+
+        # Also add any files from the message history
+        for idx, msg in enumerate(all_messages, 1):
+            for part in msg.parts:
+                if hasattr(part, "type") and part.type == "file":
+                    artifacts.append(
+                        Artifact(
+                            parts=[part],
+                            index=idx,
+                            name=part.file.name,
+                            description=f"File from message {idx}",
+                        )
+                    )
 
         task = await self.update_store(
             task_params.id,
             TaskStatus(state=task_state, message=agent_message),
-            [artifact],
+            artifacts,
         )
 
         await self._update_task_history(
@@ -400,6 +421,32 @@ class A2ATaskManager:
 
             final_message = None
 
+            # Check for files in the user message and include them as artifacts
+            user_files = []
+            for part in request.params.message.parts:
+                if (
+                    hasattr(part, "type")
+                    and part.type == "file"
+                    and hasattr(part, "file")
+                ):
+                    user_files.append(
+                        Artifact(
+                            parts=[part],
+                            index=0,
+                            name=part.file.name if part.file.name else "file",
+                            description="File from user",
+                        )
+                    )
+
+            # Send artifacts for any user files
+            for artifact in user_files:
+                yield SendTaskStreamingResponse(
+                    id=request.id,
+                    result=TaskArtifactUpdateEvent(
+                        id=request.params.id, artifact=artifact
+                    ),
+                )
+
             async for chunk in run_agent_stream(
                 agent_id=str(agent.id),
                 external_id=external_id,
@@ -418,7 +465,48 @@ class A2ATaskManager:
                     parts = content.get("parts", [])
 
                     if parts:
-                        update_message = Message(role=role, parts=parts)
+                        # Modified to handle file parts as well
+                        agent_parts = []
+                        for part in parts:
+                            # Handle different part types
+                            if part.get("type") == "text":
+                                agent_parts.append(part)
+                                full_response += part.get("text", "")
+                            elif part.get("inlineData") and part["inlineData"].get(
+                                "data"
+                            ):
+                                # Convert inline data to a file part
+                                mime_type = part["inlineData"].get(
+                                    "mimeType", "application/octet-stream"
+                                )
+                                file_name = f"file_{uuid_pkg.uuid4().hex}{self._get_extension_from_mime(mime_type)}"
+                                file_part = {
+                                    "type": "file",
+                                    "file": {
+                                        "name": file_name,
+                                        "mimeType": mime_type,
+                                        "bytes": part["inlineData"]["data"],
+                                    },
+                                }
+                                agent_parts.append(file_part)
+
+                                # Also send it as an artifact
+                                yield SendTaskStreamingResponse(
+                                    id=request.id,
+                                    result=TaskArtifactUpdateEvent(
+                                        id=request.params.id,
+                                        artifact=Artifact(
+                                            parts=[file_part],
+                                            index=0,
+                                            name=file_name,
+                                            description=f"Generated {mime_type} file",
+                                        ),
+                                    ),
+                                )
+
+                        if agent_parts:
+                            update_message = Message(role=role, parts=agent_parts)
+                            final_message = update_message
 
                         yield SendTaskStreamingResponse(
                             id=request.id,

@@ -431,11 +519,6 @@ class A2ATaskManager:
                                 final=False,
                             ),
                         )
-
-                        for part in parts:
-                            if part.get("type") == "text":
-                                full_response += part.get("text", "")
-                        final_message = update_message
                 except Exception as e:
                     logger.error(f"Error processing chunk: {e}, chunk: {chunk}")
                     continue
@@ -485,6 +568,29 @@ class A2ATaskManager:
                 error=InternalError(message=f"Error streaming task process: {str(e)}"),
             )
 
+    def _get_extension_from_mime(self, mime_type: str) -> str:
+        """Get a file extension from MIME type."""
+        if not mime_type:
+            return ""
+
+        mime_map = {
+            "image/jpeg": ".jpg",
+            "image/png": ".png",
+            "image/gif": ".gif",
+            "application/pdf": ".pdf",
+            "text/plain": ".txt",
+            "text/html": ".html",
+            "text/csv": ".csv",
+            "application/json": ".json",
+            "application/xml": ".xml",
+            "application/msword": ".doc",
+            "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
+            "application/vnd.ms-excel": ".xls",
+            "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
+        }
+
+        return mime_map.get(mime_type, "")
+
     async def update_store(
         self,
         task_id: str,
|
|||||||
return task
|
return task
|
||||||
|
|
||||||
def _extract_user_query(self, task_params: TaskSendParams) -> str:
|
def _extract_user_query(self, task_params: TaskSendParams) -> str:
|
||||||
"""Extracts the user query from the task parameters."""
|
"""Extracts the user query from the task parameters and processes any files."""
|
||||||
if not task_params.message or not task_params.message.parts:
|
if not task_params.message or not task_params.message.parts:
|
||||||
raise ValueError("Message or parts are missing in task parameters")
|
raise ValueError("Message or parts are missing in task parameters")
|
||||||
|
|
||||||
part = task_params.message.parts[0]
|
# Process file parts first
|
||||||
if part.type != "text":
|
text_parts = []
|
||||||
raise ValueError("Only text parts are supported")
|
has_files = False
|
||||||
|
file_parts = []
|
||||||
|
|
||||||
return part.text
|
logger.info(
|
||||||
|
f"Extracting query from message with {len(task_params.message.parts)} parts"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract text parts and file parts separately
|
||||||
|
for idx, part in enumerate(task_params.message.parts):
|
||||||
|
logger.info(
|
||||||
|
f"Processing part {idx+1}, type: {getattr(part, 'type', 'unknown')}"
|
||||||
|
)
|
||||||
|
if hasattr(part, "type"):
|
||||||
|
if part.type == "text":
|
||||||
|
logger.info(f"Found text part: '{part.text[:50]}...' (truncated)")
|
||||||
|
text_parts.append(part.text)
|
||||||
|
elif part.type == "file":
|
||||||
|
logger.info(
|
||||||
|
f"Found file part: {getattr(getattr(part, 'file', None), 'name', 'unnamed')}"
|
||||||
|
)
|
||||||
|
has_files = True
|
||||||
|
try:
|
||||||
|
processed_file = self._process_file_part(
|
||||||
|
part, task_params.sessionId
|
||||||
|
)
|
||||||
|
if processed_file:
|
||||||
|
file_parts.append(processed_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing file part: {e}")
|
||||||
|
# Continue with other parts even if a file fails
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown part type: {part.type}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Part has no type attribute: {part}")
|
||||||
|
|
||||||
|
# Store the file parts in self for later use
|
||||||
|
self._last_processed_files = file_parts if file_parts else None
|
||||||
|
|
||||||
|
# If we have at least one text part, use that as the query
|
||||||
|
if text_parts:
|
||||||
|
final_query = " ".join(text_parts)
|
||||||
|
logger.info(
|
||||||
|
f"Final query from text parts: '{final_query[:50]}...' (truncated)"
|
||||||
|
)
|
||||||
|
return final_query
|
||||||
|
# If we only have file parts, create a generic query asking for analysis
|
||||||
|
elif has_files:
|
||||||
|
logger.info("No text parts, using generic query for file analysis")
|
||||||
|
return "Analyze the attached files"
|
||||||
|
else:
|
||||||
|
logger.error("No supported content parts found in the message")
|
||||||
|
raise ValueError("No supported content parts found in the message")
|
||||||
|
|
||||||
|
def _process_file_part(self, part, session_id: str):
|
||||||
|
"""Processes a file part and saves it to the artifact service.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Processed file information to pass to agent_runner
|
||||||
|
"""
|
||||||
|
if not hasattr(part, "file") or not part.file:
|
||||||
|
logger.warning("File part missing file data")
|
||||||
|
return None
|
||||||
|
|
||||||
|
file_data = part.file
|
||||||
|
|
||||||
|
if not file_data.name:
|
||||||
|
file_data.name = f"file_{uuid_pkg.uuid4().hex}"
|
||||||
|
|
||||||
|
logger.info(f"Processing file {file_data.name} for session {session_id}")
|
||||||
|
|
||||||
|
if file_data.bytes:
|
||||||
|
# Process file data provided as base64 string
|
||||||
|
try:
|
||||||
|
# Convert base64 to bytes
|
||||||
|
logger.info(f"Decoding base64 content for file {file_data.name}")
|
||||||
|
file_bytes = base64.b64decode(file_data.bytes)
|
||||||
|
|
||||||
|
# Determine MIME type based on binary content
|
||||||
|
mime_type = (
|
||||||
|
file_data.mimeType if hasattr(file_data, "mimeType") else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not mime_type or mime_type == "application/octet-stream":
|
||||||
|
# Detection by byte signature
|
||||||
|
if file_bytes.startswith(b"\xff\xd8\xff"): # JPEG signature
|
||||||
|
mime_type = "image/jpeg"
|
||||||
|
elif file_bytes.startswith(b"\x89PNG\r\n\x1a\n"): # PNG signature
|
||||||
|
mime_type = "image/png"
|
||||||
|
elif file_bytes.startswith(b"GIF87a") or file_bytes.startswith(
|
||||||
|
b"GIF89a"
|
||||||
|
): # GIF
|
||||||
|
mime_type = "image/gif"
|
||||||
|
elif file_bytes.startswith(b"%PDF"): # PDF
|
||||||
|
mime_type = "application/pdf"
|
||||||
|
else:
|
||||||
|
# Fallback to avoid generic type in images
|
||||||
|
if file_data.name.lower().endswith((".jpg", ".jpeg")):
|
||||||
|
mime_type = "image/jpeg"
|
||||||
|
elif file_data.name.lower().endswith(".png"):
|
||||||
|
mime_type = "image/png"
|
||||||
|
elif file_data.name.lower().endswith(".gif"):
|
||||||
|
mime_type = "image/gif"
|
||||||
|
elif file_data.name.lower().endswith(".pdf"):
|
||||||
|
mime_type = "application/pdf"
|
||||||
|
else:
|
||||||
|
mime_type = "application/octet-stream"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Decoded file size: {len(file_bytes)} bytes, MIME type: {mime_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split session_id to get app_name and user_id
|
||||||
|
parts = session_id.split("_")
|
||||||
|
if len(parts) != 2:
|
||||||
|
user_id = session_id
|
||||||
|
app_name = "a2a"
|
||||||
|
else:
|
||||||
|
user_id, app_name = parts
|
||||||
|
|
||||||
|
# Create artifact Part
|
||||||
|
logger.info(f"Creating artifact Part for file {file_data.name}")
|
||||||
|
artifact = Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))
|
||||||
|
|
||||||
|
# Save to artifact service
|
||||||
|
logger.info(
|
||||||
|
f"Saving artifact {file_data.name} to {app_name}/{user_id}/{session_id}"
|
||||||
|
)
|
||||||
|
version = artifacts_service.save_artifact(
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
filename=file_data.name,
|
||||||
|
artifact=artifact,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Successfully saved file {file_data.name} (version {version}) for session {session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import the FileData model from the chat schema
|
||||||
|
from src.schemas.chat import FileData
|
||||||
|
|
||||||
|
# Create a FileData object instead of a dictionary
|
||||||
|
# This is compatible with what agent_runner.py expects
|
||||||
|
return FileData(
|
||||||
|
filename=file_data.name,
|
||||||
|
content_type=mime_type,
|
||||||
|
data=file_data.bytes, # Keep the original base64 format
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing file data: {str(e)}")
|
||||||
|
# Log more details about the error to help with debugging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
elif file_data.uri:
|
||||||
|
# Handling URIs would require additional implementation
|
||||||
|
# For now, log that we received a URI but can't process it
|
||||||
|
logger.warning(f"File URI references not yet implemented: {file_data.uri}")
|
||||||
|
# Future enhancement: fetch the file from the URI and save it
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict:
|
async def _run_agent(self, agent: Agent, query: str, session_id: str) -> dict:
|
||||||
"""Executes the agent to process the user query."""
|
"""Executes the agent to process the user query."""
|
||||||
try:
|
try:
|
||||||
|
files = getattr(self, "_last_processed_files", None)
|
||||||
|
|
||||||
|
if files:
|
||||||
|
logger.info(f"Passing {len(files)} files to run_agent")
|
||||||
|
for file_info in files:
|
||||||
|
logger.info(
|
||||||
|
f"File being sent: {file_info.filename} ({file_info.content_type})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("No files to pass to run_agent")
|
||||||
|
|
||||||
# We call the same function used in the chat API
|
# We call the same function used in the chat API
|
||||||
return await run_agent(
|
return await run_agent(
|
||||||
agent_id=str(agent.id),
|
agent_id=str(agent.id),
|
||||||
@@ -536,6 +816,7 @@ class A2ATaskManager:
                 artifacts_service=artifacts_service,
                 memory_service=memory_service,
                 db=self.db,
+                files=files,
             )
         except Exception as e:
             logger.error(f"Error running agent: {e}")
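The byte-signature detection in `_process_file_part` uses the standard magic numbers for JPEG, PNG, GIF, and PDF. Isolated as a sketch for clarity (the helper name here is illustrative, not part of the commit):

def sniff_mime(file_bytes: bytes, fallback: str = "application/octet-stream") -> str:
    """Guess a MIME type from well-known magic bytes, as _process_file_part does."""
    if file_bytes.startswith(b"\xff\xd8\xff"):       # JPEG
        return "image/jpeg"
    if file_bytes.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG
        return "image/png"
    if file_bytes.startswith(b"GIF87a") or file_bytes.startswith(b"GIF89a"):  # GIF
        return "image/gif"
    if file_bytes.startswith(b"%PDF"):               # PDF
        return "application/pdf"
    return fallback

assert sniff_mime(b"%PDF-1.7 ...") == "application/pdf"
assert sniff_mime(b"\x89PNG\r\n\x1a\n" + b"\x00") == "image/png"
assert sniff_mime(b"unknown") == "application/octet-stream"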
@@ -28,7 +28,7 @@
 """
 
 from google.adk.runners import Runner
-from google.genai.types import Content, Part
+from google.genai.types import Content, Part, Blob
 from google.adk.sessions import DatabaseSessionService
 from google.adk.memory import InMemoryMemoryService
 from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService

@@ -42,6 +42,7 @@ import asyncio
 import json
 from src.utils.otel import get_tracer
 from opentelemetry import trace
+import base64
 
 logger = setup_logger(__name__)
 

@@ -56,6 +57,7 @@ async def run_agent(
     db: Session,
     session_id: Optional[str] = None,
     timeout: float = 60.0,
+    files: Optional[list] = None,
 ):
     tracer = get_tracer()
     with tracer.start_as_current_span(

@@ -65,6 +67,7 @@ async def run_agent(
             "external_id": external_id,
             "session_id": session_id or f"{external_id}_{agent_id}",
             "message": message,
+            "has_files": files is not None and len(files) > 0,
         },
     ):
         exit_stack = None
@@ -74,6 +77,9 @@ async def run_agent(
         )
         logger.info(f"Received message: {message}")
 
+        if files and len(files) > 0:
+            logger.info(f"Received {len(files)} files with message")
+
         get_root_agent = get_agent(db, agent_id)
         logger.info(
             f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@@ -113,7 +119,63 @@ async def run_agent(
             session_id=adk_session_id,
         )
 
-        content = Content(role="user", parts=[Part(text=message)])
+        file_parts = []
+        if files and len(files) > 0:
+            for file_data in files:
+                try:
+                    file_bytes = base64.b64decode(file_data.data)
+
+                    logger.info(f"DEBUG - Processing file: {file_data.filename}")
+                    logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
+                    logger.info(f"DEBUG - MIME type: '{file_data.content_type}'")
+                    logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
+
+                    try:
+                        file_part = Part(
+                            inline_data=Blob(
+                                mime_type=file_data.content_type, data=file_bytes
+                            )
+                        )
+                        logger.info("DEBUG - Part created successfully")
+                    except Exception as part_error:
+                        logger.error(
+                            f"DEBUG - Error creating Part: {str(part_error)}"
+                        )
+                        logger.error(
+                            f"DEBUG - Error type: {type(part_error).__name__}"
+                        )
+                        import traceback
+
+                        logger.error(
+                            f"DEBUG - Stack trace: {traceback.format_exc()}"
+                        )
+                        raise
+
+                    # Save the file in the ArtifactService
+                    version = artifacts_service.save_artifact(
+                        app_name=agent_id,
+                        user_id=external_id,
+                        session_id=adk_session_id,
+                        filename=file_data.filename,
+                        artifact=file_part,
+                    )
+                    logger.info(
+                        f"Saved file {file_data.filename} as version {version}"
+                    )
+
+                    # Add the Part to the list of parts for the message content
+                    file_parts.append(file_part)
+                except Exception as e:
+                    logger.error(
+                        f"Error processing file {file_data.filename}: {str(e)}"
+                    )
+
+        # Create the content with the text message and the files
+        parts = [Part(text=message)]
+        if file_parts:
+            parts.extend(file_parts)
+
+        content = Content(role="user", parts=parts)
         logger.info("Starting agent execution")
 
         final_response_text = "No final response captured."
@@ -256,6 +318,7 @@ async def run_agent_stream(
     memory_service: InMemoryMemoryService,
     db: Session,
     session_id: Optional[str] = None,
+    files: Optional[list] = None,
 ) -> AsyncGenerator[str, None]:
     tracer = get_tracer()
     span = tracer.start_span(

@@ -265,6 +328,7 @@ async def run_agent_stream(
             "external_id": external_id,
             "session_id": session_id or f"{external_id}_{agent_id}",
             "message": message,
+            "has_files": files is not None and len(files) > 0,
         },
     )
     try:

@@ -275,6 +339,9 @@ async def run_agent_stream(
         )
         logger.info(f"Received message: {message}")
 
+        if files and len(files) > 0:
+            logger.info(f"Received {len(files)} files with message")
+
         get_root_agent = get_agent(db, agent_id)
         logger.info(
             f"Root agent found: {get_root_agent.name} (type: {get_root_agent.type})"
@@ -314,7 +381,72 @@ async def run_agent_stream(
             session_id=adk_session_id,
         )
 
-        content = Content(role="user", parts=[Part(text=message)])
+        # Process the received files
+        file_parts = []
+        if files and len(files) > 0:
+            for file_data in files:
+                try:
+                    # Decode the base64 file
+                    file_bytes = base64.b64decode(file_data.data)
+
+                    # Detailed debug
+                    logger.info(
+                        f"DEBUG - Processing file: {file_data.filename}"
+                    )
+                    logger.info(f"DEBUG - File size: {len(file_bytes)} bytes")
+                    logger.info(
+                        f"DEBUG - MIME type: '{file_data.content_type}'"
+                    )
+                    logger.info(f"DEBUG - First 20 bytes: {file_bytes[:20]}")
+
+                    # Create a Part for the file using the default constructor
+                    try:
+                        file_part = Part(
+                            inline_data=Blob(
+                                mime_type=file_data.content_type,
+                                data=file_bytes,
+                            )
+                        )
+                        logger.info("DEBUG - Part created successfully")
+                    except Exception as part_error:
+                        logger.error(
+                            f"DEBUG - Error creating Part: {str(part_error)}"
+                        )
+                        logger.error(
+                            f"DEBUG - Error type: {type(part_error).__name__}"
+                        )
+                        import traceback
+
+                        logger.error(
+                            f"DEBUG - Stack trace: {traceback.format_exc()}"
+                        )
+                        raise
+
+                    # Save the file in the ArtifactService
+                    version = artifacts_service.save_artifact(
+                        app_name=agent_id,
+                        user_id=external_id,
+                        session_id=adk_session_id,
+                        filename=file_data.filename,
+                        artifact=file_part,
+                    )
+                    logger.info(
+                        f"Saved file {file_data.filename} as version {version}"
+                    )
+
+                    # Add the Part to the list of parts for the message content
+                    file_parts.append(file_part)
+                except Exception as e:
+                    logger.error(
+                        f"Error processing file {file_data.filename}: {str(e)}"
+                    )
+
+        # Create the content with the text message and the files
+        parts = [Part(text=message)]
+        if file_parts:
+            parts.extend(file_parts)
+
+        content = Content(role="user", parts=parts)
         logger.info("Starting agent streaming execution")
 
         try:
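A sketch of driving the streaming runner with an attached file. Parameter names beyond those visible in this diff (`message`, `session_service`) are inferred from the call sites above and may differ:

import asyncio
import base64
import json

from src.schemas.chat import FileData
from src.services.agent_runner import run_agent_stream

async def stream_with_file(db, session_service, artifacts_service, memory_service):
    files = [
        FileData(
            filename="report.pdf",
            content_type="application/pdf",
            data=base64.b64encode(b"%PDF-1.4 ...").decode("utf-8"),  # placeholder bytes
        )
    ]
    async for chunk in run_agent_stream(
        agent_id="agent-123",        # hypothetical IDs
        external_id="user-1",
        message="Summarize the attached PDF",   # name inferred, not shown in this diff
        session_service=session_service,        # name inferred, not shown in this diff
        artifacts_service=artifacts_service,
        memory_service=memory_service,
        db=db,
        files=files,
    ):
        print(json.loads(chunk))  # chunks are JSON strings, as the chat route assumes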
@@ -27,10 +27,16 @@
 └──────────────────────────────────────────────────────────────────────────────┘
 """
 
+import base64
+import uuid
+from typing import Dict, List, Any, Optional
+from google.genai.types import Part, Blob
+
 from src.schemas.a2a_types import (
     ContentTypeNotSupportedError,
     JSONRPCResponse,
     UnsupportedOperationError,
+    Message,
 )
 
 
@@ -55,3 +61,130 @@ def new_incompatible_types_error(request_id):
 
 def new_not_implemented_error(request_id):
     return JSONRPCResponse(id=request_id, error=UnsupportedOperationError())
+
+
+def extract_files_from_message(message: Message) -> List[Dict[str, Any]]:
+    """
+    Extract file parts from an A2A message.
+
+    Args:
+        message: An A2A Message object
+
+    Returns:
+        List of file parts extracted from the message
+    """
+    if not message or not message.parts:
+        return []
+
+    files = []
+    for part in message.parts:
+        if hasattr(part, "type") and part.type == "file" and hasattr(part, "file"):
+            files.append(part)
+
+    return files
+
+
+def a2a_part_to_adk_part(a2a_part: Dict[str, Any]) -> Optional[Part]:
+    """
+    Convert an A2A protocol part to an ADK Part object.
+
+    Args:
+        a2a_part: An A2A part dictionary
+
+    Returns:
+        Converted ADK Part object, or None if conversion is not possible
+    """
+    part_type = a2a_part.get("type")
+    if part_type == "file" and "file" in a2a_part:
+        file_data = a2a_part["file"]
+        if "bytes" in file_data:
+            try:
+                # Convert base64 to bytes
+                file_bytes = base64.b64decode(file_data["bytes"])
+                mime_type = file_data.get("mimeType", "application/octet-stream")
+
+                # Create the ADK Part
+                return Part(inline_data=Blob(mime_type=mime_type, data=file_bytes))
+            except Exception:
+                return None
+    elif part_type == "text" and "text" in a2a_part:
+        # For text parts, we could create a text blob if needed
+        return None
+
+    return None
+
+
+def adk_part_to_a2a_part(
+    adk_part: Part, filename: Optional[str] = None
+) -> Optional[Dict[str, Any]]:
+    """
+    Convert an ADK Part object to an A2A protocol part.
+
+    Args:
+        adk_part: An ADK Part object
+        filename: Optional filename to use
+
+    Returns:
+        Converted A2A part dictionary, or None if conversion is not possible
+    """
+    if hasattr(adk_part, "inline_data") and adk_part.inline_data:
+        if adk_part.inline_data.data and adk_part.inline_data.mime_type:
+            # Convert binary data to base64
+            file_bytes = adk_part.inline_data.data
+            mime_type = adk_part.inline_data.mime_type
+
+            # Generate a filename if one is not provided
+            if not filename:
+                ext = get_extension_from_mime(mime_type)
+                filename = f"file_{uuid.uuid4().hex}{ext}"
+
+            # Convert to an A2A FilePart dict
+            return {
+                "type": "file",
+                "file": {
+                    "name": filename,
+                    "mimeType": mime_type,
+                    "bytes": (
+                        base64.b64encode(file_bytes).decode("utf-8")
+                        if isinstance(file_bytes, bytes)
+                        else str(file_bytes)
+                    ),
+                },
+            }
+    elif hasattr(adk_part, "text") and adk_part.text:
+        # Convert a text part
+        return {"type": "text", "text": adk_part.text}
+
+    return None
+
+
+def get_extension_from_mime(mime_type: str) -> str:
+    """
+    Get a file extension from MIME type.
+
+    Args:
+        mime_type: MIME type string
+
+    Returns:
+        Appropriate file extension with leading dot
+    """
+    if not mime_type:
+        return ""
+
+    mime_map = {
+        "image/jpeg": ".jpg",
+        "image/png": ".png",
+        "image/gif": ".gif",
+        "application/pdf": ".pdf",
+        "text/plain": ".txt",
+        "text/html": ".html",
+        "text/csv": ".csv",
+        "application/json": ".json",
+        "application/xml": ".xml",
+        "application/msword": ".doc",
+        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
+        "application/vnd.ms-excel": ".xls",
+        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
+    }
+
+    return mime_map.get(mime_type, "")
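A round trip through the new converters (a sketch; the import path is illustrative, assuming the helpers live in the utils module this hunk touches):

import base64

from src.utils.a2a_utils import a2a_part_to_adk_part, adk_part_to_a2a_part  # assumed path

a2a_part = {
    "type": "file",
    "file": {
        "name": "hello.txt",
        "mimeType": "text/plain",
        "bytes": base64.b64encode(b"hello").decode("utf-8"),
    },
}

adk_part = a2a_part_to_adk_part(a2a_part)
assert adk_part.inline_data.mime_type == "text/plain"
assert adk_part.inline_data.data == b"hello"

# Converting back generates a fresh filename from the MIME map (.txt here).
back = adk_part_to_a2a_part(adk_part)
assert back["file"]["mimeType"] == "text/plain"
assert base64.b64decode(back["file"]["bytes"]) == b"hello"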