950 lines
34 KiB
Python
950 lines
34 KiB
Python
"""
|
|
A2A Task Manager Service.
|
|
|
|
This service implements task management for the A2A protocol, handling task lifecycle
|
|
including execution, streaming, push notifications, status queries, and cancellation.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any, Dict, Union, AsyncIterable
|
|
import uuid
|
|
|
|
from src.schemas.a2a.exceptions import (
|
|
TaskNotFoundError,
|
|
TaskNotCancelableError,
|
|
PushNotificationNotSupportedError,
|
|
InternalError,
|
|
ContentTypeNotSupportedError,
|
|
)
|
|
|
|
from src.schemas.a2a.types import (
|
|
JSONRPCResponse,
|
|
GetTaskRequest,
|
|
SendTaskRequest,
|
|
CancelTaskRequest,
|
|
SetTaskPushNotificationRequest,
|
|
GetTaskPushNotificationRequest,
|
|
GetTaskResponse,
|
|
CancelTaskResponse,
|
|
SendTaskResponse,
|
|
SetTaskPushNotificationResponse,
|
|
GetTaskPushNotificationResponse,
|
|
TaskSendParams,
|
|
TaskStatus,
|
|
TaskState,
|
|
TaskResubscriptionRequest,
|
|
SendTaskStreamingRequest,
|
|
SendTaskStreamingResponse,
|
|
Artifact,
|
|
PushNotificationConfig,
|
|
TaskStatusUpdateEvent,
|
|
TaskArtifactUpdateEvent,
|
|
JSONRPCError,
|
|
TaskPushNotificationConfig,
|
|
Message,
|
|
TextPart,
|
|
Task,
|
|
)
|
|
from src.services.redis_cache_service import RedisCacheService
|
|
from src.utils.a2a_utils import (
|
|
are_modalities_compatible,
|
|
new_incompatible_types_error,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class A2ATaskManager:
    """
    Task manager for the A2A protocol.

    Handles the task lifecycle exposed over JSON-RPC:
    - submitting and executing tasks
    - querying task status
    - cancelling tasks
    - configuring push notifications
    - streaming real-time updates over SSE
    """

    def __init__(
        self,
        redis_cache: RedisCacheService,
        agent_runner=None,
        streaming_service=None,
        push_notification_service=None,
        db=None,
    ):
        """
        Keep references to the collaborating services and create streaming state.

        Args:
            redis_cache: Redis cache service used to store task data
            agent_runner: Optional agent runner used to execute tasks
            streaming_service: Optional streaming service for SSE
            push_notification_service: Optional service used to send push notifications
            db: Optional database session
        """
        # In-process streaming state: task id -> list of subscriber queues,
        # accessed under subscriber_lock.
        self.task_sse_subscribers = {}
        self.subscriber_lock = asyncio.Lock()
        self.lock = asyncio.Lock()

        # Collaborators are stored as given; the optional ones may be None.
        self.db = db
        self.push_notification_service = push_notification_service
        self.streaming_service = streaming_service
        self.agent_runner = agent_runner
        self.redis_cache = redis_cache
|
|
|
|
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
    """
    Handle request to get task information.

    Args:
        request: A2A Get Task request; ``params.id`` selects the task and
            ``params.historyLength`` optionally limits the returned history

    Returns:
        Response with the task, or with TaskNotFoundError when the cache has
        no entry for the id, or with InternalError when any exception occurs
    """
    try:
        task_id = request.params.id
        history_length = request.params.historyLength

        # Get task data from cache
        task_data = await self.redis_cache.get(f"task:{task_id}")

        if not task_data:
            logger.warning(f"Task not found: {task_id}")
            return GetTaskResponse(id=request.id, error=TaskNotFoundError())

        # Create a Task instance from cache data
        task = Task.model_validate(task_data)

        # Trim the history to the most recent `history_length` messages.
        if history_length is not None and task.history:
            if history_length <= 0:
                # A negative value must not reach the slice below:
                # history[-(-n):] == history[n:] would drop the *oldest*
                # messages instead of limiting the history. Treat it as 0.
                task.history = []
            elif len(task.history) > history_length:
                task.history = task.history[-history_length:]

        return GetTaskResponse(id=request.id, result=task)
    except Exception as e:
        # exc_info keeps the traceback, matching the other request handlers.
        logger.error(f"Error processing on_get_task: {str(e)}", exc_info=True)
        return GetTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
|
|
|
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
    """
    Cancel a task that is still submitted or working.

    Args:
        request: The JSON-RPC request identifying the task to cancel

    Returns:
        Response carrying the updated task data, or an error when the task is
        missing, is not in a cancellable state, or an exception occurs
    """
    logger.info(f"Cancelling task {request.params.id}")
    params = request.params

    try:
        cache_key = f"task:{params.id}"
        task_data = await self.redis_cache.get(cache_key)
        if not task_data:
            logger.warning(f"Task {params.id} not found for cancellation")
            return CancelTaskResponse(id=request.id, error=TaskNotFoundError())

        # Only tasks that have not reached another state can be cancelled.
        current_state = task_data.get("status", {}).get("state")
        cancellable_states = (TaskState.SUBMITTED, TaskState.WORKING)
        if current_state not in cancellable_states:
            logger.warning(
                f"Task {params.id} in state {current_state} cannot be cancelled"
            )
            return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

        # Overwrite the stored status with the cancelled one and persist it.
        cancelled_status = {
            "state": TaskState.CANCELED,
            "timestamp": datetime.now().isoformat(),
            "message": {
                "role": "agent",
                "parts": [{"type": "text", "text": "Task cancelled by user"}],
            },
        }
        task_data["status"] = cancelled_status
        await self.redis_cache.set(cache_key, task_data)

        # Push notification first (a no-op when nothing is configured) ...
        await self._send_push_notification_for_task(
            params.id, "canceled", system_message="Task cancelled by user"
        )

        # ... then the final status event for SSE subscribers.
        event_message = Message(
            role="agent",
            parts=[TextPart(text="Task cancelled by user")],
        )
        event_status = TaskStatus(
            state=TaskState.CANCELED,
            timestamp=datetime.now(),
            message=event_message,
        )
        final_event = TaskStatusUpdateEvent(
            id=params.id, status=event_status, final=True
        )
        await self._publish_task_update(params.id, final_event)

        return CancelTaskResponse(id=request.id, result=task_data)

    except Exception as e:
        logger.error(f"Error cancelling task: {str(e)}", exc_info=True)
        failure = InternalError(message=f"Error cancelling task: {str(e)}")
        return CancelTaskResponse(id=request.id, error=failure)
|
|
|
|
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
    """
    Handle request to send a new task and run it to completion.

    Args:
        request: Send Task request

    Returns:
        Response with the task after execution; the existing task when one
        with the same id is already working or completed; or an error when
        output modes are incompatible or an exception occurs
    """
    try:
        params = request.params
        task_id = params.id
        logger.info(f"Receiving task {task_id}")

        # Check if a task with this ID already exists
        existing_task = await self.redis_cache.get(f"task:{task_id}")
        if existing_task:
            # A task that is working or completed is returned as-is
            if existing_task.get("status", {}).get("state") in [
                TaskState.WORKING,
                TaskState.COMPLETED,
            ]:
                return SendTaskResponse(
                    id=request.id, result=Task.model_validate(existing_task)
                )

            # Any other state falls through and the task is processed again
            logger.info(f"Reprocessing existing task {task_id}")

        # Check modality compatibility
        server_output_modes = []
        if self.agent_runner:
            # Try to get supported modes from the agent
            try:
                server_output_modes = await self.agent_runner.get_supported_modes()
            except Exception as e:
                logger.warning(f"Error getting supported modes: {str(e)}")
                server_output_modes = ["text"]  # Fallback to text

        if not are_modalities_compatible(
            server_output_modes, params.acceptedOutputModes
        ):
            logger.warning(
                f"Incompatible modes: server={server_output_modes}, client={params.acceptedOutputModes}"
            )
            return SendTaskResponse(
                id=request.id, error=ContentTypeNotSupportedError()
            )

        # Create task data. _create_task_data already stores it under
        # "task:{id}", so no second cache write is needed here.
        task_data = await self._create_task_data(params)

        # Configure push notifications, if provided
        if params.pushNotification:
            # Original key, kept for any reader outside this class.
            await self.redis_cache.set(
                f"task_notification:{task_id}", params.pushNotification.model_dump()
            )
            # Key and shape that _send_push_notification_for_task and
            # on_get_task_push_notification read. Without this write, a config
            # supplied at submission time was never found by them.
            await self.redis_cache.set(
                f"task:{task_id}:push",
                {"url": params.pushNotification.url, "headers": {}},
            )

        # Execute task SYNCHRONOUSLY instead of in background
        # This is the key change for A2A compatibility
        task_data = await self._execute_task(task_data, params)

        # Convert to Task object and return
        task = Task.model_validate(task_data)
        return SendTaskResponse(id=request.id, result=task)

    except Exception as e:
        # exc_info keeps the traceback, matching the other request handlers.
        logger.error(f"Error processing on_send_task: {str(e)}", exc_info=True)
        return SendTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
|
|
|
async def on_send_task_subscribe(
    self, request: SendTaskStreamingRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
    """
    Handle request to send a task and subscribe to streaming updates.

    The task is executed in a background asyncio task while the caller
    receives an async generator that yields the events published for it.

    Args:
        request: The JSON-RPC request to send a task with streaming

    Returns:
        Stream of events, or an error response when the output modes are
        incompatible or setup fails
    """
    logger.info(f"Sending task with streaming {request.params.id}")
    task_send_params = request.params

    try:
        # Check output mode compatibility
        if not are_modalities_compatible(
            ["text", "application/json"],  # Default supported modes
            task_send_params.acceptedOutputModes,
        ):
            return new_incompatible_types_error(request.id)

        # Create initial task data
        task_data = await self._create_task_data(task_send_params)

        # Setup SSE consumer before execution starts so no event is missed
        sse_queue = await self._setup_sse_consumer(task_send_params.id)

        # Execute the task in the background. The event loop only keeps a
        # weak reference to tasks, so a task with no other reference can be
        # garbage-collected before it finishes. Hold a strong reference until
        # it is done. The set is created lazily so this method does not
        # depend on __init__ declaring it.
        execution = asyncio.create_task(
            self._execute_task(task_data, task_send_params)
        )
        background_tasks = getattr(self, "_background_tasks", None)
        if background_tasks is None:
            background_tasks = set()
            self._background_tasks = background_tasks
        background_tasks.add(execution)
        execution.add_done_callback(background_tasks.discard)

        # Return generator to dequeue events for SSE
        return self._dequeue_events_for_sse(
            request.id, task_send_params.id, sse_queue
        )

    except Exception as e:
        logger.error(f"Error setting up streaming task: {str(e)}", exc_info=True)
        return SendTaskStreamingResponse(
            id=request.id,
            error=InternalError(
                message=f"Error setting up streaming task: {str(e)}"
            ),
        )
|
|
|
|
async def on_set_task_push_notification(
    self, request: SetTaskPushNotificationRequest
) -> SetTaskPushNotificationResponse:
    """
    Store the push notification target for an existing task.

    Args:
        request: The JSON-RPC request carrying the task id and its
            pushNotificationConfig

    Returns:
        Response echoing the request params on success, or an error when no
        push notification service is configured, the task is unknown, or an
        exception occurs
    """
    logger.info(f"Setting push notification for task {request.params.id}")
    params = request.params

    def failed(error):
        # Every error path answers with the same response type and request id.
        return SetTaskPushNotificationResponse(id=request.id, error=error)

    try:
        if not self.push_notification_service:
            logger.warning("Push notifications not supported")
            return failed(PushNotificationNotSupportedError())

        # The configuration is only accepted for a task present in the cache.
        known_task = await self.redis_cache.get(f"task:{params.id}")
        if not known_task:
            logger.warning(
                f"Task {params.id} not found for setting push notification"
            )
            return failed(TaskNotFoundError())

        # Only the URL is stored; headers start out empty.
        push_config = {
            "url": params.pushNotificationConfig.url,
            "headers": {},  # Add auth headers if needed
        }
        await self.redis_cache.set(f"task:{params.id}:push", push_config)

        return SetTaskPushNotificationResponse(id=request.id, result=params)

    except Exception as e:
        logger.error(f"Error setting push notification: {str(e)}", exc_info=True)
        return failed(
            InternalError(message=f"Error setting push notification: {str(e)}")
        )
|
|
|
|
async def on_get_task_push_notification(
    self, request: GetTaskPushNotificationRequest
) -> GetTaskPushNotificationResponse:
    """
    Return the stored push notification target of a task.

    Args:
        request: The JSON-RPC request identifying the task

    Returns:
        Response with the stored URL wrapped in a TaskPushNotificationConfig,
        or an error when no push notification service is configured, the task
        or its configuration is missing, or an exception occurs
    """
    logger.info(f"Getting push notification for task {request.params.id}")
    params = request.params

    def failed(error):
        # Every error path answers with the same response type and request id.
        return GetTaskPushNotificationResponse(id=request.id, error=error)

    try:
        if not self.push_notification_service:
            logger.warning("Push notifications not supported")
            return failed(PushNotificationNotSupportedError())

        # The task itself must be present in the cache.
        known_task = await self.redis_cache.get(f"task:{params.id}")
        if not known_task:
            logger.warning(
                f"Task {params.id} not found for getting push notification"
            )
            return failed(TaskNotFoundError())

        # The configuration lives under its own key next to the task.
        stored = await self.redis_cache.get(f"task:{params.id}:push")
        if not stored:
            logger.warning(f"No push notification config for task {params.id}")
            return failed(
                InternalError(message="No push notification configuration found")
            )

        # Only the URL is read back; token and authentication are left unset.
        push_config = PushNotificationConfig(
            url=stored.get("url"), token=None, authentication=None
        )
        result = TaskPushNotificationConfig(
            id=params.id, pushNotificationConfig=push_config
        )
        return GetTaskPushNotificationResponse(id=request.id, result=result)

    except Exception as e:
        logger.error(f"Error getting push notification: {str(e)}", exc_info=True)
        return failed(
            InternalError(message=f"Error getting push notification: {str(e)}")
        )
|
|
|
|
async def on_resubscribe_to_task(
    self, request: TaskResubscriptionRequest
) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
    """
    Attach a new SSE subscriber to a task that already has subscribers.

    The new subscriber first receives the task's current status, then the
    events published afterwards.

    Args:
        request: The JSON-RPC request to resubscribe

    Returns:
        Stream of events, or an error response when the task is unknown, has
        no active subscription to join, or an exception occurs
    """
    logger.info(f"Resubscribing to task {request.params.id}")
    params = request.params

    try:
        # The task must be present in the cache.
        task_data = await self.redis_cache.get(f"task:{params.id}")
        if not task_data:
            logger.warning(f"Task {params.id} not found for resubscription")
            return JSONRPCResponse(id=request.id, error=TaskNotFoundError())

        # _setup_sse_consumer raises ValueError when the task has no
        # subscriber list to join.
        try:
            sse_queue = await self._setup_sse_consumer(
                params.id, is_resubscribe=True
            )
        except ValueError:
            logger.warning(f"Task {params.id} not available for resubscription")
            unavailable = InternalError(
                message="Task not available for resubscription"
            )
            return JSONRPCResponse(id=request.id, error=unavailable)

        # Rebuild the current status from the cached dictionary.
        stored_status = task_data.get("status", {})
        terminal_states = (
            TaskState.COMPLETED,
            TaskState.FAILED,
            TaskState.CANCELED,
        )
        is_final = stored_status.get("state") in terminal_states

        raw_timestamp = stored_status.get("timestamp", datetime.now().isoformat())
        current_status = TaskStatus(
            state=stored_status.get("state", TaskState.UNKNOWN),
            timestamp=datetime.fromisoformat(raw_timestamp),
            message=stored_status.get("message"),
        )

        # Deliver the snapshot to this subscriber's queue only.
        snapshot = TaskStatusUpdateEvent(
            id=params.id, status=current_status, final=is_final
        )
        await sse_queue.put(snapshot)

        # Return generator to dequeue events for SSE
        return self._dequeue_events_for_sse(request.id, params.id, sse_queue)

    except Exception as e:
        logger.error(f"Error resubscribing to task: {str(e)}", exc_info=True)
        failure = InternalError(message=f"Error resubscribing to task: {str(e)}")
        return JSONRPCResponse(id=request.id, error=failure)
|
|
|
|
async def _create_task_data(self, params: TaskSendParams) -> Dict[str, Any]:
    """
    Build the initial task dictionary and store it in the cache.

    Args:
        params: Task send parameters

    Returns:
        The task dictionary that was written under ``task:{params.id}``
    """
    # Keep the caller's sessionId when provided; otherwise generate one.
    session_id = params.sessionId or str(uuid.uuid4())

    # A new task starts as submitted, with no status message or error yet.
    initial_status = {
        "state": TaskState.SUBMITTED,
        "timestamp": datetime.now().isoformat(),
        "message": None,
        "error": None,
    }

    task_data = {
        "id": params.id,
        "sessionId": session_id,
        "status": initial_status,
        "artifacts": [],
        # The history starts with the user's message only.
        "history": [params.message.model_dump()],
        "metadata": params.metadata or {},
    }

    # Persist before returning so later lookups by id find the task.
    await self.redis_cache.set(f"task:{params.id}", task_data)
    return task_data
|
|
|
|
async def _execute_task(
    self, task: Dict[str, Any], params: TaskSendParams
) -> Dict[str, Any]:
    """
    Execute a task using the agent runner.

    Concatenates the text parts of the request message, marks the task as
    working, runs the agent, and records the outcome: on success the reply
    is stored as an artifact and the task is completed; otherwise the task
    is marked failed with a reason.

    Args:
        task: Task data to be executed
        params: Send task parameters

    Returns:
        The task data re-read from the cache after all updates
    """
    task_id = task["id"]
    agent_id = params.agentId

    # Extract the text from the message (non-text parts are ignored)
    message_text = ""
    if params.message and params.message.parts:
        message_text = "".join(
            part.text for part in params.message.parts if part.type == "text"
        )

    if not message_text:
        await self._update_task_status_without_history(
            task_id, TaskState.FAILED, "Message does not contain text", final=True
        )
        # Return the updated task data
        return await self.redis_cache.get(f"task:{task_id}")

    # Check if it is an ongoing execution
    task_status = task.get("status", {})
    if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
        logger.info(f"Task {task_id} is already in execution or completed")
        # Return the current task data
        return await self.redis_cache.get(f"task:{task_id}")

    try:
        # Update to "working" state - do NOT add to the history
        await self._update_task_status_without_history(
            task_id, TaskState.WORKING, "Processing request"
        )

        # Every failure branch only differs by its reason, so collect the
        # reason and report it once at the end.
        failure_reason = None

        if not self.agent_runner:
            failure_reason = "Agent adapter not configured"
        else:
            # NOTE(review): the runner receives the request's sessionId, which
            # may be None even though _create_task_data stores a generated
            # one on the task. Confirm this is intended.
            response = await self.agent_runner.run_agent(
                agent_id=agent_id,
                message=message_text,
                session_id=params.sessionId,
                task_id=task_id,
            )

            if not response or not isinstance(response, dict):
                failure_reason = "Invalid agent response"
            else:
                # Prefer "content"; fall back to the text parts of "message".
                # `or` guards against keys that are present but None, which
                # previously raised TypeError and failed an otherwise usable
                # response.
                response_text = response.get("content") or ""
                if not response_text and "message" in response:
                    reply = response.get("message") or {}
                    response_text = "".join(
                        part.get("text") or ""
                        for part in (reply.get("parts") or [])
                        if part.get("type") == "text"
                    )

                if not response_text:
                    failure_reason = "The agent did not return a valid response"
                else:
                    # Make sure the user's message is in the history
                    await self._update_task_history(task_id, params.message)

                    # Create an artifact for the response in Google A2A format
                    artifact = {
                        "parts": [{"type": "text", "text": response_text}],
                        "index": 0,
                    }
                    await self._add_task_artifact(task_id, artifact)

                    # Mark the task completed (without adding to the history)
                    await self._update_task_status_without_history(
                        task_id, TaskState.COMPLETED, response_text, final=True
                    )

        if failure_reason is not None:
            await self._update_task_status_without_history(
                task_id, TaskState.FAILED, failure_reason, final=True
            )
    except Exception as e:
        logger.error(f"Error executing task {task_id}: {str(e)}", exc_info=True)
        await self._update_task_status_without_history(
            task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
        )

    # Return the updated task data
    return await self.redis_cache.get(f"task:{task_id}")
|
|
|
|
async def _update_task_history(self, task_id: str, message) -> None:
    """
    Append a user message to the task history unless it is already there.

    Args:
        task_id: Task ID
        message: User message to record; either a dict or an object exposing
            ``model_dump()``
    """
    if not message:
        return

    # Load the current task data
    cache_key = f"task:{task_id}"
    task_data = await self.redis_cache.get(cache_key)
    if not task_data:
        logger.warning(f"Task {task_id} not found for history update")
        return

    # Make sure a history list exists
    history = task_data.setdefault("history", [])

    # Compare in dictionary form, which is how history entries are stored
    if hasattr(message, "model_dump"):
        user_message = message.model_dump()
    else:
        user_message = message

    # Skip the append when an equivalent message is already recorded
    already_recorded = any(
        self._compare_messages(entry, user_message) for entry in history
    )
    if not already_recorded:
        history.append(user_message)
        logger.info(f"Added new user message to history for task {task_id}")

    # The task is written back whether or not the history changed
    await self.redis_cache.set(cache_key, task_data)
|
|
|
|
# Helper used to compare messages and avoid duplicates in the history
def _compare_messages(self, msg1: Dict, msg2: Dict) -> bool:
    """
    Check whether two message dictionaries are essentially the same.

    Two messages match when their ``role`` is equal and their ``parts``
    lists have the same length with pairwise equal ``type`` and ``text``.
    A missing or ``None`` ``parts`` value is treated as an empty list.

    Args:
        msg1: First message dictionary
        msg2: Second message dictionary

    Returns:
        True when the messages match, False otherwise
    """
    if msg1.get("role") != msg2.get("role"):
        return False

    # `or []` also covers a "parts" key that is present but None, which
    # would otherwise make len() raise TypeError.
    parts1 = msg1.get("parts") or []
    parts2 = msg2.get("parts") or []

    if len(parts1) != len(parts2):
        return False

    # NOTE(review): only "type" and "text" are compared, so two non-text
    # parts of the same type count as equal. Confirm this is intended.
    return all(
        left.get("type") == right.get("type")
        and left.get("text") == right.get("text")
        for left, right in zip(parts1, parts2)
    )
|
|
|
|
async def _update_task_status_without_history(
    self, task_id: str, state: TaskState, message_text: str, final: bool = False
) -> None:
    """
    Replace a task's status, leaving its history untouched.

    The new status is stored in the cache, published to SSE subscribers and,
    for final or terminal states, forwarded as a push notification. Any
    exception is logged and not propagated.

    Args:
        task_id: ID of the task to be updated
        state: New task state
        message_text: Text of the agent message attached to the status
        final: Indicates if this is the final status of the task
    """
    try:
        cache_key = f"task:{task_id}"
        task_data = await self.redis_cache.get(cache_key)
        if not task_data:
            logger.warning(f"Unable to update status: task {task_id} not found")
            return

        # Wrap the text in an agent message and build the status around it.
        status_message = Message(role="agent", parts=[TextPart(text=message_text)])
        new_status = TaskStatus(
            state=state, message=status_message, timestamp=datetime.now()
        )

        # Persist the dictionary form of the status on the task.
        task_data["status"] = new_status.model_dump(exclude_none=True)
        await self.redis_cache.set(cache_key, task_data)

        # Tell SSE subscribers about the change.
        update_event = TaskStatusUpdateEvent(
            id=task_id, status=new_status, final=final
        )
        await self._publish_task_update(task_id, update_event)

        # Push notifications go out for final updates and terminal states.
        terminal_states = (
            TaskState.FAILED,
            TaskState.COMPLETED,
            TaskState.CANCELED,
        )
        if final or state in terminal_states:
            await self._send_push_notification_for_task(
                task_id=task_id, state=state, message_text=message_text
            )
    except Exception as e:
        logger.error(f"Error updating task status {task_id}: {str(e)}")
|
|
|
|
async def _add_task_artifact(self, task_id: str, artifact) -> None:
    """
    Append an artifact to the stored task and publish an artifact event.

    The event is published even when the task is not found in the cache.

    Args:
        task_id: Task ID
        artifact: Artifact dictionary with ``parts`` and ``index`` keys
    """
    logger.info(f"Adding artifact to task {task_id}")

    # Append to the stored artifacts without replacing existing ones.
    cache_key = f"task:{task_id}"
    task_data = await self.redis_cache.get(cache_key)
    if task_data:
        task_data.setdefault("artifacts", []).append(artifact)
        await self.redis_cache.set(cache_key, task_data)

    # Build a typed Artifact for the event; each part becomes a TextPart.
    text_parts = []
    for part in artifact.get("parts", []):
        text_parts.append(TextPart(text=part.get("text", "")))
    artifact_obj = Artifact(parts=text_parts, index=artifact.get("index", 0))

    # Publish the artifact update to the task's subscribers.
    artifact_event = TaskArtifactUpdateEvent(id=task_id, artifact=artifact_obj)
    await self._publish_task_update(task_id, artifact_event)
|
|
|
|
async def _publish_task_update(
    self, task_id: str, event: Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent]
) -> None:
    """
    Put an event on the queue of every subscriber of a task.

    Does nothing when the task has no subscribers. A failure to deliver to
    one subscriber is logged and does not stop delivery to the others.

    Args:
        task_id: Task ID
        event: Event to publish
    """
    async with self.subscriber_lock:
        # An unknown task id yields an empty iteration.
        for subscriber_queue in self.task_sse_subscribers.get(task_id, ()):
            try:
                await subscriber_queue.put(event)
            except Exception as e:
                logger.error(f"Error publishing event to subscriber: {str(e)}")
|
|
|
|
async def _send_push_notification_for_task(
    self,
    task_id: str,
    state: str,
    message_text: str = None,
    system_message: str = None,
) -> None:
    """
    Send a push notification for a task when one is configured.

    Nothing is sent when there is no push notification service or no stored
    configuration under ``task:{task_id}:push``. Errors are logged and not
    propagated.

    Args:
        task_id: Task ID
        state: Task state
        message_text: Optional message text; takes precedence when given
        system_message: Optional system message, used when message_text is empty
    """
    if not self.push_notification_service:
        return

    try:
        config = await self.redis_cache.get(f"task:{task_id}:push")
        if not config:
            return

        # Both kinds of text are sent with the 'agent' role, because Message
        # only accepts 'user' or 'agent'.
        text = message_text or system_message
        if text:
            message = {
                "role": "agent",
                "parts": [{"type": "text", "text": text}],
            }
        else:
            message = None

        notification = {
            "url": config["url"],
            "task_id": task_id,
            "state": state,
            "message": message,
            "headers": config.get("headers", {}),
        }
        await self.push_notification_service.send_notification(**notification)

    except Exception as e:
        logger.error(
            f"Error sending push notification for task {task_id}: {str(e)}"
        )
|
|
|
|
async def _setup_sse_consumer(
    self, task_id: str, is_resubscribe: bool = False
) -> asyncio.Queue:
    """
    Register a new SSE subscriber queue for a task.

    Args:
        task_id: Task ID
        is_resubscribe: Whether this is a resubscription; when True the task
            must already have a subscriber list

    Returns:
        The newly created queue, already added to the task's subscribers

    Raises:
        ValueError: If resubscribing to a task with no subscriber list
    """
    async with self.subscriber_lock:
        known_task = task_id in self.task_sse_subscribers

        # A resubscription may only join an existing subscriber list.
        if is_resubscribe and not known_task:
            raise ValueError("Task not found for resubscription")

        # First subscriber of a task creates its list.
        if not known_task:
            self.task_sse_subscribers[task_id] = []

        new_queue = asyncio.Queue()
        self.task_sse_subscribers[task_id].append(new_queue)
        return new_queue
|
|
|
|
async def _dequeue_events_for_sse(
    self, request_id: str, task_id: str, event_queue: asyncio.Queue
) -> AsyncIterable[SendTaskStreamingResponse]:
    """
    Yield queued events as streaming responses until the stream ends.

    The stream ends after a JSONRPCError or after an event whose ``final``
    attribute is truthy. On exit the queue is removed from the task's
    subscribers.

    Args:
        request_id: Request ID
        task_id: Task ID
        event_queue: Queue for events

    Yields:
        SSE events wrapped in SendTaskStreamingResponse
    """
    try:
        while True:
            event = await event_queue.get()

            # An error is reported once and closes the stream.
            if isinstance(event, JSONRPCError):
                yield SendTaskStreamingResponse(id=request_id, error=event)
                break

            yield SendTaskStreamingResponse(id=request_id, result=event)

            # Events without a `final` attribute never end the stream.
            if getattr(event, "final", False):
                break
    finally:
        # Clean up the subscription when done
        async with self.subscriber_lock:
            subscribers = self.task_sse_subscribers.get(task_id)
            # The queue might have been removed already.
            if subscribers is not None and event_queue in subscribers:
                subscribers.remove(event_queue)
                # Remove the task from the dict if no more subscribers
                if not subscribers:
                    del self.task_sse_subscribers[task_id]
|