feat(api): add agent_card_url handling for a2a type agents
This commit is contained in:
parent
96df2db27d
commit
c14d23333c
@ -0,0 +1,33 @@
|
||||
"""add_a2a_fields_in_agents_table
|
||||
|
||||
Revision ID: 07ac76cc090a
|
||||
Revises: 6cd898ec9f7c
|
||||
Create Date: 2025-04-30 17:32:29.582234
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "07ac76cc090a"
|
||||
down_revision: Union[str, None] = "6cd898ec9f7c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,32 @@
|
||||
"""add_a2a_fields_in_agents_table
|
||||
|
||||
Revision ID: 545d3083200b
|
||||
Revises: 07ac76cc090a
|
||||
Create Date: 2025-04-30 17:35:31.573159
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '545d3083200b'
|
||||
down_revision: Union[str, None] = '07ac76cc090a'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('agents', sa.Column('agent_card_url', sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('agents', 'agent_card_url')
|
||||
# ### end Alembic commands ###
|
@ -15,6 +15,7 @@ from src.services import (
|
||||
agent_service,
|
||||
mcp_server_service,
|
||||
)
|
||||
from src.models.models import Agent as AgentModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -74,7 +75,12 @@ async def create_agent(
|
||||
# Verify if the user has access to the agent's client
|
||||
await verify_user_client(payload, db, agent.client_id)
|
||||
|
||||
return agent_service.create_agent(db, agent)
|
||||
db_agent = await agent_service.create_agent(db, agent)
|
||||
|
||||
if not db_agent.agent_card_url:
|
||||
db_agent.agent_card_url = db_agent.agent_card_url_property
|
||||
|
||||
return db_agent
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Agent])
|
||||
@ -88,7 +94,13 @@ async def read_agents(
|
||||
# Verify if the user has access to this client's data
|
||||
await verify_user_client(payload, db, x_client_id)
|
||||
|
||||
return agent_service.get_agents_by_client(db, x_client_id, skip, limit)
|
||||
agents = agent_service.get_agents_by_client(db, x_client_id, skip, limit)
|
||||
|
||||
for agent in agents:
|
||||
if not agent.agent_card_url:
|
||||
agent.agent_card_url = agent.agent_card_url_property
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=Agent)
|
||||
@ -107,6 +119,9 @@ async def read_agent(
|
||||
# Verify if the user has access to the agent's client
|
||||
await verify_user_client(payload, db, x_client_id)
|
||||
|
||||
if not db_agent.agent_card_url:
|
||||
db_agent.agent_card_url = db_agent.agent_card_url_property
|
||||
|
||||
return db_agent
|
||||
|
||||
|
||||
@ -132,7 +147,12 @@ async def update_agent(
|
||||
new_client_id = uuid.UUID(agent_data["client_id"])
|
||||
await verify_user_client(payload, db, new_client_id)
|
||||
|
||||
return await agent_service.update_agent(db, agent_id, agent_data)
|
||||
updated_agent = await agent_service.update_agent(db, agent_id, agent_data)
|
||||
|
||||
if not updated_agent.agent_card_url:
|
||||
updated_agent.agent_card_url = updated_agent.agent_card_url_property
|
||||
|
||||
return updated_agent
|
||||
|
||||
|
||||
@router.delete("/{agent_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
@ -69,7 +69,7 @@ async def verify_user_client(
|
||||
required_client_id: Client ID to be verified
|
||||
|
||||
Returns:
|
||||
bool: True se verificado com sucesso
|
||||
bool: True if verified successfully
|
||||
|
||||
Raises:
|
||||
HTTPException: If the user does not have permission
|
||||
@ -78,7 +78,7 @@ async def verify_user_client(
|
||||
if payload.get("is_admin", False):
|
||||
return True
|
||||
|
||||
# Para não-admins, verificar se o client_id corresponde
|
||||
# For non-admins, verify if the client_id corresponds
|
||||
user_client_id = payload.get("client_id")
|
||||
if not user_client_id:
|
||||
logger.warning(
|
||||
@ -153,8 +153,8 @@ def get_current_user_client_id(
|
||||
|
||||
async def get_jwt_token_ws(token: str) -> Optional[dict]:
|
||||
"""
|
||||
Verifica e decodifica o token JWT para WebSocket.
|
||||
Retorna o payload se o token for válido, None caso contrário.
|
||||
Verifies and decodes the JWT token for WebSocket.
|
||||
Returns the payload if the token is valid, None otherwise.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
|
@ -74,22 +74,25 @@ class Agent(Base):
|
||||
model = Column(String, nullable=True, default="")
|
||||
api_key = Column(String, nullable=True, default="")
|
||||
instruction = Column(Text)
|
||||
agent_card_url = Column(String, nullable=True)
|
||||
config = Column(JSON, default={})
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint(
|
||||
"type IN ('llm', 'sequential', 'parallel', 'loop')", name="check_agent_type"
|
||||
"type IN ('llm', 'sequential', 'parallel', 'loop', 'a2a')",
|
||||
name="check_agent_type",
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def agent_card_url(self) -> str:
|
||||
"""URL virtual para o agent card que não é rastrada no banco de dados"""
|
||||
return (
|
||||
f"{os.getenv('API_URL', '')}/api/v1/agents/{self.id}/.well-known/agent.json"
|
||||
)
|
||||
def agent_card_url_property(self) -> str:
|
||||
"""Virtual URL for the agent card"""
|
||||
if self.agent_card_url:
|
||||
return self.agent_card_url
|
||||
|
||||
return f"{os.getenv('API_URL', '')}/api/v1/a2a/{self.id}/.well-known/agent.json"
|
||||
|
||||
def to_dict(self):
|
||||
"""Converts the object to a dictionary, converting UUIDs to strings"""
|
||||
@ -112,8 +115,7 @@ class Agent(Base):
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
# Adiciona a propriedade virtual ao dicionário
|
||||
result["agent_card_url"] = self.agent_card_url
|
||||
result["agent_card_url"] = self.agent_card_url_property
|
||||
return result
|
||||
|
||||
def _convert_dict(self, d):
|
||||
|
@ -52,9 +52,13 @@ class Contact(ContactBase):
|
||||
|
||||
|
||||
class AgentBase(BaseModel):
|
||||
name: str = Field(..., description="Agent name (no spaces or special characters)")
|
||||
name: Optional[str] = Field(
|
||||
None, description="Agent name (no spaces or special characters)"
|
||||
)
|
||||
description: Optional[str] = Field(None, description="Agent description")
|
||||
type: str = Field(..., description="Agent type (llm, sequential, parallel, loop)")
|
||||
type: str = Field(
|
||||
..., description="Agent type (llm, sequential, parallel, loop, a2a)"
|
||||
)
|
||||
model: Optional[str] = Field(
|
||||
None, description="Agent model (required only for llm type)"
|
||||
)
|
||||
@ -62,24 +66,42 @@ class AgentBase(BaseModel):
|
||||
None, description="Agent API Key (required only for llm type)"
|
||||
)
|
||||
instruction: Optional[str] = None
|
||||
config: Union[LLMConfig, Dict[str, Any]] = Field(
|
||||
..., description="Agent configuration based on type"
|
||||
agent_card_url: Optional[str] = Field(
|
||||
None, description="Agent card URL (required for a2a type)"
|
||||
)
|
||||
config: Optional[Union[LLMConfig, Dict[str, Any]]] = Field(
|
||||
None, description="Agent configuration based on type"
|
||||
)
|
||||
|
||||
@validator("name")
|
||||
def validate_name(cls, v):
|
||||
def validate_name(cls, v, values):
|
||||
if values.get("type") == "a2a":
|
||||
return v
|
||||
|
||||
if not v:
|
||||
raise ValueError("Name is required for non-a2a agent types")
|
||||
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
||||
raise ValueError("Agent name cannot contain spaces or special characters")
|
||||
return v
|
||||
|
||||
@validator("type")
|
||||
def validate_type(cls, v):
|
||||
if v not in ["llm", "sequential", "parallel", "loop"]:
|
||||
if v not in ["llm", "sequential", "parallel", "loop", "a2a"]:
|
||||
raise ValueError(
|
||||
"Invalid agent type. Must be: llm, sequential, parallel or loop"
|
||||
"Invalid agent type. Must be: llm, sequential, parallel, loop or a2a"
|
||||
)
|
||||
return v
|
||||
|
||||
@validator("agent_card_url")
|
||||
def validate_agent_card_url(cls, v, values):
|
||||
if "type" in values and values["type"] == "a2a":
|
||||
if not v:
|
||||
raise ValueError("agent_card_url is required for a2a type agents")
|
||||
if not v.endswith("/.well-known/agent.json"):
|
||||
raise ValueError("agent_card_url must end with /.well-known/agent.json")
|
||||
return v
|
||||
|
||||
@validator("model")
|
||||
def validate_model(cls, v, values):
|
||||
if "type" in values and values["type"] == "llm" and not v:
|
||||
@ -94,9 +116,17 @@ class AgentBase(BaseModel):
|
||||
|
||||
@validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
if "type" in values and values["type"] == "a2a":
|
||||
return v or {}
|
||||
|
||||
if "type" not in values:
|
||||
return v
|
||||
|
||||
if not v and values.get("type") != "a2a":
|
||||
raise ValueError(
|
||||
f"Configuration is required for {values.get('type')} agent type"
|
||||
)
|
||||
|
||||
if values["type"] == "llm":
|
||||
if isinstance(v, dict):
|
||||
try:
|
||||
@ -134,6 +164,18 @@ class Agent(AgentBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@validator("agent_card_url", pre=True)
|
||||
def set_agent_card_url(cls, v, values):
|
||||
if v:
|
||||
return v
|
||||
|
||||
if "id" in values:
|
||||
from os import getenv
|
||||
|
||||
return f"{getenv('API_URL', '')}/api/v1/a2a/{values['id']}/.well-known/agent.json"
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: str
|
||||
|
@ -85,12 +85,6 @@ class AgentRunnerAdapter:
|
||||
Returns:
|
||||
Dictionary with the agent's response
|
||||
"""
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] run_agent iniciado - agent_id={agent_id}, task_id={task_id}, session_id={session_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] run_agent - message: '{message[:50]}...' (truncado)"
|
||||
)
|
||||
|
||||
try:
|
||||
# Use the existing agent runner function
|
||||
@ -100,18 +94,6 @@ class AgentRunnerAdapter:
|
||||
# Use the provided db or fallback to self.db
|
||||
db_session = db if db is not None else self.db
|
||||
|
||||
if db_session is None:
|
||||
logger.error(
|
||||
f"[AGENT-RUNNER] No database session available. db={db}, self.db={self.db}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] Using database session: {type(db_session).__name__}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] Chamando agent_runner_func com agent_id={agent_id}, contact_id={task_id}"
|
||||
)
|
||||
response_text = await self.agent_runner_func(
|
||||
agent_id=agent_id,
|
||||
contact_id=task_id,
|
||||
@ -123,13 +105,6 @@ class AgentRunnerAdapter:
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] run_agent concluído com sucesso para agent_id={agent_id}, task_id={task_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[AGENT-RUNNER] resposta: '{str(response_text)[:50]}...' (truncado)"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"content": response_text,
|
||||
@ -216,7 +191,6 @@ class StreamingServiceAdapter:
|
||||
status_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=working_status, final=False
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(status_event.model_dump())
|
||||
|
||||
content_buffer = ""
|
||||
@ -229,9 +203,8 @@ class StreamingServiceAdapter:
|
||||
# To streaming, we use task_id as contact_id
|
||||
contact_id = task_id
|
||||
|
||||
# Adicionar tratamento de heartbeat para manter conexão ativa
|
||||
last_event_time = datetime.now()
|
||||
heartbeat_interval = 20 # segundos
|
||||
heartbeat_interval = 20
|
||||
|
||||
async for event in self.streaming_service.send_task_streaming(
|
||||
agent_id=agent_id,
|
||||
@ -241,7 +214,6 @@ class StreamingServiceAdapter:
|
||||
session_id=session_id,
|
||||
db=db,
|
||||
):
|
||||
# Atualizar timestamp do último evento
|
||||
last_event_time = datetime.now()
|
||||
|
||||
# Process the streaming event format
|
||||
@ -268,7 +240,6 @@ class StreamingServiceAdapter:
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(artifact_event.model_dump())
|
||||
|
||||
# Check if final event
|
||||
@ -299,7 +270,7 @@ class StreamingServiceAdapter:
|
||||
final_artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=final_artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
|
||||
yield json.dumps(final_artifact_event.model_dump())
|
||||
|
||||
# Send the completed status
|
||||
@ -308,7 +279,7 @@ class StreamingServiceAdapter:
|
||||
status=completed_status,
|
||||
final=True,
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
|
||||
yield json.dumps(final_status_event.model_dump())
|
||||
|
||||
final_sent = True
|
||||
@ -333,7 +304,7 @@ class StreamingServiceAdapter:
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
|
||||
yield json.dumps(artifact_event.model_dump())
|
||||
elif isinstance(event_data, dict):
|
||||
# Try to extract text from the dictionary
|
||||
@ -351,14 +322,14 @@ class StreamingServiceAdapter:
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
id=task_id, artifact=artifact
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
|
||||
yield json.dumps(artifact_event.model_dump())
|
||||
|
||||
# Enviar heartbeat/keep-alive para manter a conexão SSE aberta
|
||||
# Send heartbeat/keep-alive to keep the SSE connection open
|
||||
now = datetime.now()
|
||||
if (now - last_event_time).total_seconds() > heartbeat_interval:
|
||||
logger.info(f"Sending heartbeat for task {task_id}")
|
||||
# Enviando evento de keep-alive como um evento de status de "working"
|
||||
# Sending keep-alive event as a "working" status event
|
||||
working_heartbeat = TaskStatus(
|
||||
state="working",
|
||||
timestamp=now,
|
||||
@ -369,7 +340,6 @@ class StreamingServiceAdapter:
|
||||
heartbeat_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=working_heartbeat, final=False
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(heartbeat_event.model_dump())
|
||||
last_event_time = now
|
||||
|
||||
@ -392,7 +362,6 @@ class StreamingServiceAdapter:
|
||||
final_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=completed_status, final=True
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(final_event.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
@ -416,11 +385,10 @@ class StreamingServiceAdapter:
|
||||
error_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=failed_status, final=True
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(error_event.model_dump())
|
||||
|
||||
finally:
|
||||
# Garantir que enviamos um evento final para fechar a conexão corretamente
|
||||
# Ensure we send a final event to properly close the connection
|
||||
if not final_sent and not has_error:
|
||||
logger.info(f"Stream finalizing for task {task_id} via finally block")
|
||||
try:
|
||||
@ -442,7 +410,6 @@ class StreamingServiceAdapter:
|
||||
final_event = TaskStatusUpdateEvent(
|
||||
id=task_id, status=completed_status, final=True
|
||||
)
|
||||
# IMPORTANTE: Converter para string JSON para SSE
|
||||
yield json.dumps(final_event.model_dump())
|
||||
except Exception as final_error:
|
||||
logger.error(
|
||||
@ -479,7 +446,6 @@ def create_agent_card_from_agent(agent, db) -> AgentCard:
|
||||
|
||||
# We create a new thread to execute the asynchronous function
|
||||
import concurrent.futures
|
||||
import functools
|
||||
|
||||
def run_async(coro):
|
||||
loop = asyncio.new_event_loop()
|
||||
|
@ -530,13 +530,9 @@ class A2AServer:
|
||||
body = await request.json()
|
||||
logger.info(f"Received JSON data: {json.dumps(body)}")
|
||||
method = body.get("method", "unknown")
|
||||
logger.info(f"[SERVER] Processando método: {method}")
|
||||
|
||||
# Validate the request using the A2A validator
|
||||
json_rpc_request = A2ARequest.validate_python(body)
|
||||
logger.info(
|
||||
f"[SERVER] Request validado como: {type(json_rpc_request).__name__}"
|
||||
)
|
||||
|
||||
original_db = self.task_manager.db
|
||||
try:
|
||||
@ -546,55 +542,34 @@ class A2AServer:
|
||||
|
||||
# Process the request
|
||||
if isinstance(json_rpc_request, SendTaskRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando SendTaskRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
json_rpc_request.params.agentId = agent_id
|
||||
result = await self.task_manager.on_send_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, SendTaskStreamingRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando SendTaskStreamingRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
json_rpc_request.params.agentId = agent_id
|
||||
result = await self.task_manager.on_send_task_subscribe(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, GetTaskRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando GetTaskRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_get_task(json_rpc_request)
|
||||
elif isinstance(json_rpc_request, CancelTaskRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando CancelTaskRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_cancel_task(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, SetTaskPushNotificationRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando SetTaskPushNotificationRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_set_task_push_notification(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, GetTaskPushNotificationRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando GetTaskPushNotificationRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_get_task_push_notification(
|
||||
json_rpc_request
|
||||
)
|
||||
elif isinstance(json_rpc_request, TaskResubscriptionRequest):
|
||||
logger.info(
|
||||
f"[SERVER] Processando TaskResubscriptionRequest para task_id={json_rpc_request.params.id}"
|
||||
)
|
||||
result = await self.task_manager.on_resubscribe_to_task(
|
||||
json_rpc_request
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SERVER] Tipo de request não suportado: {type(json_rpc_request)}"
|
||||
f"[SERVER] Request type not supported: {type(json_rpc_request)}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
|
@ -104,29 +104,29 @@ class A2ATaskManager:
|
||||
|
||||
async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
|
||||
"""
|
||||
Manipula requisição para obter informações sobre uma tarefa.
|
||||
Handle request to get task information.
|
||||
|
||||
Args:
|
||||
request: Requisição Get Task do A2A
|
||||
request: A2A Get Task request
|
||||
|
||||
Returns:
|
||||
Resposta com os detalhes da tarefa
|
||||
Response with task details
|
||||
"""
|
||||
try:
|
||||
task_id = request.params.id
|
||||
history_length = request.params.historyLength
|
||||
|
||||
# Busca dados da tarefa do cache
|
||||
# Get task data from cache
|
||||
task_data = await self.redis_cache.get(f"task:{task_id}")
|
||||
|
||||
if not task_data:
|
||||
logger.warning(f"Tarefa não encontrada: {task_id}")
|
||||
logger.warning(f"Task not found: {task_id}")
|
||||
return GetTaskResponse(id=request.id, error=TaskNotFoundError())
|
||||
|
||||
# Cria uma instância Task a partir dos dados do cache
|
||||
# Create a Task instance from cache data
|
||||
task = Task.model_validate(task_data)
|
||||
|
||||
# Se o parâmetro historyLength estiver presente, manipula o histórico
|
||||
# If historyLength parameter is present, handle the history
|
||||
if history_length is not None and task.history:
|
||||
if history_length == 0:
|
||||
task.history = []
|
||||
@ -135,7 +135,7 @@ class A2ATaskManager:
|
||||
|
||||
return GetTaskResponse(id=request.id, result=task)
|
||||
except Exception as e:
|
||||
logger.error(f"Erro ao processar on_get_task: {str(e)}")
|
||||
logger.error(f"Error processing on_get_task: {str(e)}")
|
||||
return GetTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
||||
|
||||
async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
|
||||
@ -211,78 +211,75 @@ class A2ATaskManager:
|
||||
|
||||
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
|
||||
"""
|
||||
Manipula requisição para enviar uma nova tarefa.
|
||||
Handle request to send a new task.
|
||||
|
||||
Args:
|
||||
request: Requisição de envio de tarefa
|
||||
request: Send Task request
|
||||
|
||||
Returns:
|
||||
Resposta com os detalhes da tarefa criada
|
||||
Response with the created task details
|
||||
"""
|
||||
try:
|
||||
params = request.params
|
||||
task_id = params.id
|
||||
logger.info(f"Recebendo tarefa {task_id}")
|
||||
logger.info(f"Receiving task {task_id}")
|
||||
|
||||
# Verifica se já existe uma tarefa com esse ID
|
||||
# Check if a task with this ID already exists
|
||||
existing_task = await self.redis_cache.get(f"task:{task_id}")
|
||||
if existing_task:
|
||||
# Se a tarefa já existe e está em progresso, retorna a tarefa atual
|
||||
# If the task already exists and is in progress, return the current task
|
||||
if existing_task.get("status", {}).get("state") in [
|
||||
TaskState.WORKING,
|
||||
TaskState.COMPLETED,
|
||||
]:
|
||||
logger.info(
|
||||
f"Tarefa {task_id} já existe e está em progresso/concluída"
|
||||
)
|
||||
return SendTaskResponse(
|
||||
id=request.id, result=Task.model_validate(existing_task)
|
||||
)
|
||||
|
||||
# Se a tarefa existe mas falhou ou foi cancelada, podemos reprocessá-la
|
||||
logger.info(f"Reprocessando tarefa existente {task_id}")
|
||||
# If the task exists but failed or was canceled, we can reprocess it
|
||||
logger.info(f"Reprocessing existing task {task_id}")
|
||||
|
||||
# Verifica compatibilidade de modalidades
|
||||
# Check modality compatibility
|
||||
server_output_modes = []
|
||||
if self.agent_runner:
|
||||
# Tenta obter modos suportados do agente
|
||||
# Try to get supported modes from the agent
|
||||
try:
|
||||
server_output_modes = await self.agent_runner.get_supported_modes()
|
||||
except Exception as e:
|
||||
logger.warning(f"Erro ao obter modos suportados: {str(e)}")
|
||||
server_output_modes = ["text"] # Fallback para texto
|
||||
logger.warning(f"Error getting supported modes: {str(e)}")
|
||||
server_output_modes = ["text"] # Fallback to text
|
||||
|
||||
if not are_modalities_compatible(
|
||||
server_output_modes, params.acceptedOutputModes
|
||||
):
|
||||
logger.warning(
|
||||
f"Modos incompatíveis: servidor={server_output_modes}, cliente={params.acceptedOutputModes}"
|
||||
f"Incompatible modes: server={server_output_modes}, client={params.acceptedOutputModes}"
|
||||
)
|
||||
return SendTaskResponse(
|
||||
id=request.id, error=ContentTypeNotSupportedError()
|
||||
)
|
||||
|
||||
# Cria dados da tarefa
|
||||
# Create task data
|
||||
task_data = await self._create_task_data(params)
|
||||
|
||||
# Armazena a tarefa no cache
|
||||
# Store task in cache
|
||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Configura notificações push, se fornecidas
|
||||
# Configure push notifications, if provided
|
||||
if params.pushNotification:
|
||||
await self.redis_cache.set(
|
||||
f"task_notification:{task_id}", params.pushNotification.model_dump()
|
||||
)
|
||||
|
||||
# Inicia a execução da tarefa em background
|
||||
# Start task execution in background
|
||||
asyncio.create_task(self._execute_task(task_data, params))
|
||||
|
||||
# Converte para objeto Task e retorna
|
||||
# Convert to Task object and return
|
||||
task = Task.model_validate(task_data)
|
||||
return SendTaskResponse(id=request.id, result=task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erro ao processar on_send_task: {str(e)}")
|
||||
logger.error(f"Error processing on_send_task: {str(e)}")
|
||||
return SendTaskResponse(id=request.id, error=InternalError(message=str(e)))
|
||||
|
||||
async def on_send_task_subscribe(
|
||||
@ -553,20 +550,20 @@ class A2ATaskManager:
|
||||
|
||||
async def _execute_task(self, task: Dict[str, Any], params: TaskSendParams) -> None:
|
||||
"""
|
||||
Executa uma tarefa usando o adaptador do agente.
|
||||
Execute a task using the agent adapter.
|
||||
|
||||
Esta função é responsável pela execução real da tarefa pelo agente,
|
||||
atualizando seu status conforme o progresso.
|
||||
This function is responsible for executing the task by the agent,
|
||||
updating its status as progress is made.
|
||||
|
||||
Args:
|
||||
task: Dados da tarefa a ser executada
|
||||
params: Parâmetros de envio da tarefa
|
||||
task: Task data to be executed
|
||||
params: Send task parameters
|
||||
"""
|
||||
task_id = task["id"]
|
||||
agent_id = params.agentId
|
||||
message_text = ""
|
||||
|
||||
# Extrai o texto da mensagem
|
||||
# Extract the text from the message
|
||||
if params.message and params.message.parts:
|
||||
for part in params.message.parts:
|
||||
if part.type == "text":
|
||||
@ -574,23 +571,23 @@ class A2ATaskManager:
|
||||
|
||||
if not message_text:
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.FAILED, "Mensagem não contém texto", final=True
|
||||
task_id, TaskState.FAILED, "Message does not contain text", final=True
|
||||
)
|
||||
return
|
||||
|
||||
# Verificamos se é uma execução em andamento
|
||||
# Check if it is an ongoing execution
|
||||
task_status = task.get("status", {})
|
||||
if task_status.get("state") in [TaskState.WORKING, TaskState.COMPLETED]:
|
||||
logger.info(f"Tarefa {task_id} já está em execução ou concluída")
|
||||
logger.info(f"Task {task_id} is already in execution or completed")
|
||||
return
|
||||
|
||||
try:
|
||||
# Atualiza para estado "working"
|
||||
# Update to "working" state
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.WORKING, "Processando solicitação"
|
||||
task_id, TaskState.WORKING, "Processing request"
|
||||
)
|
||||
|
||||
# Executa o agente
|
||||
# Execute the agent
|
||||
if self.agent_runner:
|
||||
response = await self.agent_runner.run_agent(
|
||||
agent_id=agent_id,
|
||||
@ -599,9 +596,9 @@ class A2ATaskManager:
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
# Processa a resposta do agente
|
||||
# Process the agent's response
|
||||
if response and isinstance(response, dict):
|
||||
# Extrai texto da resposta
|
||||
# Extract text from the response
|
||||
response_text = response.get("content", "")
|
||||
if not response_text and "message" in response:
|
||||
message = response.get("message", {})
|
||||
@ -610,9 +607,9 @@ class A2ATaskManager:
|
||||
if part.get("type") == "text":
|
||||
response_text += part.get("text", "")
|
||||
|
||||
# Constrói a mensagem final do agente
|
||||
# Build the final agent message
|
||||
if response_text:
|
||||
# Cria um artefato para a resposta
|
||||
# Create an artifact for the response
|
||||
artifact = Artifact(
|
||||
name="response",
|
||||
parts=[TextPart(text=response_text)],
|
||||
@ -620,10 +617,10 @@ class A2ATaskManager:
|
||||
lastChunk=True,
|
||||
)
|
||||
|
||||
# Adiciona o artefato à tarefa
|
||||
# Add the artifact to the task
|
||||
await self._add_task_artifact(task_id, artifact)
|
||||
|
||||
# Atualiza o status da tarefa para completado
|
||||
# Update the task status to completed
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.COMPLETED, response_text, final=True
|
||||
)
|
||||
@ -631,51 +628,49 @@ class A2ATaskManager:
|
||||
await self._update_task_status(
|
||||
task_id,
|
||||
TaskState.FAILED,
|
||||
"O agente não retornou uma resposta válida",
|
||||
"The agent did not return a valid response",
|
||||
final=True,
|
||||
)
|
||||
else:
|
||||
await self._update_task_status(
|
||||
task_id,
|
||||
TaskState.FAILED,
|
||||
"Resposta inválida do agente",
|
||||
"Invalid agent response",
|
||||
final=True,
|
||||
)
|
||||
else:
|
||||
await self._update_task_status(
|
||||
task_id,
|
||||
TaskState.FAILED,
|
||||
"Adaptador do agente não configurado",
|
||||
"Agent adapter not configured",
|
||||
final=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Erro na execução da tarefa {task_id}: {str(e)}")
|
||||
logger.error(f"Error executing task {task_id}: {str(e)}")
|
||||
await self._update_task_status(
|
||||
task_id, TaskState.FAILED, f"Erro ao processar: {str(e)}", final=True
|
||||
task_id, TaskState.FAILED, f"Error processing: {str(e)}", final=True
|
||||
)
|
||||
|
||||
async def _update_task_status(
|
||||
self, task_id: str, state: TaskState, message_text: str, final: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Atualiza o status de uma tarefa.
|
||||
Update the status of a task.
|
||||
|
||||
Args:
|
||||
task_id: ID da tarefa a ser atualizada
|
||||
state: Novo estado da tarefa
|
||||
message_text: Texto da mensagem associada ao status
|
||||
final: Indica se este é o status final da tarefa
|
||||
task_id: ID of the task to be updated
|
||||
state: New task state
|
||||
message_text: Text of the message associated with the status
|
||||
final: Indicates if this is the final status of the task
|
||||
"""
|
||||
try:
|
||||
# Busca dados atuais da tarefa
|
||||
# Get current task data
|
||||
task_data = await self.redis_cache.get(f"task:{task_id}")
|
||||
if not task_data:
|
||||
logger.warning(
|
||||
f"Não foi possível atualizar status: tarefa {task_id} não encontrada"
|
||||
)
|
||||
logger.warning(f"Unable to update status: task {task_id} not found")
|
||||
return
|
||||
|
||||
# Cria objeto de status com a mensagem
|
||||
# Create status object with the message
|
||||
agent_message = Message(
|
||||
role="agent",
|
||||
parts=[TextPart(text=message_text)],
|
||||
@ -686,26 +681,26 @@ class A2ATaskManager:
|
||||
state=state, message=agent_message, timestamp=datetime.now()
|
||||
)
|
||||
|
||||
# Atualiza o status na tarefa
|
||||
# Update the status in the task
|
||||
task_data["status"] = status.model_dump(exclude_none=True)
|
||||
|
||||
# Atualiza o histórico, se existir
|
||||
# Update the history, if it exists
|
||||
if "history" not in task_data:
|
||||
task_data["history"] = []
|
||||
|
||||
# Adiciona a mensagem ao histórico
|
||||
# Add the message to the history
|
||||
task_data["history"].append(agent_message.model_dump(exclude_none=True))
|
||||
|
||||
# Armazena a tarefa atualizada
|
||||
# Store the updated task
|
||||
await self.redis_cache.set(f"task:{task_id}", task_data)
|
||||
|
||||
# Cria evento de atualização de status
|
||||
# Create status update event
|
||||
status_event = TaskStatusUpdateEvent(id=task_id, status=status, final=final)
|
||||
|
||||
# Publica atualização
|
||||
# Publish status update
|
||||
await self._publish_task_update(task_id, status_event)
|
||||
|
||||
# Envia notificação push, se configurada
|
||||
# Send push notification, if configured
|
||||
if final or state in [
|
||||
TaskState.FAILED,
|
||||
TaskState.COMPLETED,
|
||||
@ -715,7 +710,7 @@ class A2ATaskManager:
|
||||
task_id=task_id, state=state, message_text=message_text
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Erro ao atualizar status da tarefa {task_id}: {str(e)}")
|
||||
logger.error(f"Error updating task status {task_id}: {str(e)}")
|
||||
|
||||
async def _add_task_artifact(self, task_id: str, artifact: Artifact) -> None:
|
||||
"""
|
||||
|
@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any, Union
|
||||
from src.services.mcp_server_service import get_mcp_server
|
||||
import uuid
|
||||
import logging
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -67,11 +68,50 @@ def get_agents_by_client(
|
||||
)
|
||||
|
||||
|
||||
def create_agent(db: Session, agent: AgentCreate) -> Agent:
|
||||
async def create_agent(db: Session, agent: AgentCreate) -> Agent:
|
||||
"""Create a new agent"""
|
||||
try:
|
||||
# Additional sub-agent validation
|
||||
if agent.type != "llm":
|
||||
# Special handling for a2a type agents
|
||||
if agent.type == "a2a":
|
||||
if not agent.agent_card_url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="agent_card_url is required for a2a type agents",
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch agent card information
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(agent.agent_card_url)
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Failed to fetch agent card: HTTP {response.status_code}",
|
||||
)
|
||||
agent_card = response.json()
|
||||
|
||||
# Update agent with information from agent card
|
||||
agent.name = agent_card.get("name", "Unknown Agent")
|
||||
agent.description = agent_card.get("description", "")
|
||||
|
||||
if agent.config is None:
|
||||
agent.config = {}
|
||||
|
||||
# Store the whole agent card in config
|
||||
if isinstance(agent.config, dict):
|
||||
agent.config["agent_card"] = agent_card
|
||||
else:
|
||||
agent.config = {"agent_card": agent_card}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent card: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Failed to process agent card: {str(e)}",
|
||||
)
|
||||
|
||||
# Additional sub-agent validation (for non-llm and non-a2a types)
|
||||
elif agent.type != "llm":
|
||||
if not isinstance(agent.config, dict):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@ -170,6 +210,82 @@ async def update_agent(
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
if "type" in agent_data and agent_data["type"] == "a2a":
|
||||
if "agent_card_url" not in agent_data or not agent_data["agent_card_url"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="agent_card_url is required for a2a type agents",
|
||||
)
|
||||
|
||||
if not agent_data["agent_card_url"].endswith("/.well-known/agent.json"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="agent_card_url must end with /.well-known/agent.json",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(agent_data["agent_card_url"])
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch agent card: HTTP {response.status_code}",
|
||||
)
|
||||
agent_card = response.json()
|
||||
|
||||
agent_data["name"] = agent_card.get("name", "Unknown Agent")
|
||||
agent_data["description"] = agent_card.get("description", "")
|
||||
|
||||
if "config" not in agent_data or agent_data["config"] is None:
|
||||
agent_data["config"] = agent.config if agent.config else {}
|
||||
|
||||
agent_data["config"]["agent_card"] = agent_card
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent card: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to process agent card: {str(e)}",
|
||||
)
|
||||
|
||||
elif "agent_card_url" in agent_data and agent.type == "a2a":
|
||||
if not agent_data["agent_card_url"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="agent_card_url cannot be empty for a2a type agents",
|
||||
)
|
||||
|
||||
if not agent_data["agent_card_url"].endswith("/.well-known/agent.json"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="agent_card_url must end with /.well-known/agent.json",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(agent_data["agent_card_url"])
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch agent card: HTTP {response.status_code}",
|
||||
)
|
||||
agent_card = response.json()
|
||||
|
||||
agent_data["name"] = agent_card.get("name", "Unknown Agent")
|
||||
agent_data["description"] = agent_card.get("description", "")
|
||||
|
||||
if "config" not in agent_data or agent_data["config"] is None:
|
||||
agent_data["config"] = agent.config if agent.config else {}
|
||||
|
||||
agent_data["config"]["agent_card"] = agent_card
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent card: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to process agent card: {str(e)}",
|
||||
)
|
||||
|
||||
# Convert UUIDs to strings before saving
|
||||
if "config" in agent_data:
|
||||
config = agent_data["config"]
|
||||
|
Loading…
Reference in New Issue
Block a user