fix: comprehensive websocket and LLM model validation
- Fix original websocket format error
- Add robust model validation for all agent types
- Prevent empty model strings in LiteLLM calls
- Update Pydantic V2 compatibility (dict() -> model_dump())
- Improve error handling in workflow agents
- Add comprehensive logging and validation
This commit is contained in:
parent 473cf63252
commit d4618fa345
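The heart of this commit is the Pydantic V1 -> V2 migration visible throughout the diffs below. As a reference, a minimal sketch of the pattern (class and field names here are illustrative, not taken from this codebase):

```python
from typing import Optional
from pydantic import BaseModel, ConfigDict, field_validator

class ExampleAgent(BaseModel):
    # V2: model_config replaces the inner `class Config`
    model_config = ConfigDict(from_attributes=True)

    name: str
    type: str = "llm"
    model: Optional[str] = None

    # V2: @field_validator + @classmethod replaces V1's @validator
    @field_validator("model")
    @classmethod
    def validate_model(cls, v, info):
        # info.data holds previously validated fields (the V1 `values` dict)
        if info.data.get("type") == "llm" and (not v or not v.strip()):
            raise ValueError("LLM agents require a non-empty model")
        return v

# V2: model_dump() replaces the deprecated .dict()
payload = ExampleAgent(name="demo", model="gpt-4").model_dump()
```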
@@ -158,7 +158,7 @@ async def get_agent_messages(

     processed_events = []
     for event in events:
-        event_dict = event.dict()
+        event_dict = event.model_dump()

        def process_dict(d):
            if isinstance(d, dict):
@@ -27,7 +27,7 @@
 └──────────────────────────────────────────────────────────────────────────────┘
 """

-from pydantic import BaseModel, Field, validator, UUID4, ConfigDict
+from pydantic import BaseModel, Field, field_validator, model_validator, UUID4, ConfigDict
 from typing import Optional, Dict, Any, List
 from datetime import datetime
 from uuid import UUID
@@ -40,7 +40,8 @@ class ClientBase(BaseModel):
     name: str
     email: Optional[str] = None

-    @validator("email")
+    @field_validator("email")
+    @classmethod
     def validate_email(cls, v):
         if v is None:
             return v
@@ -58,8 +59,7 @@ class Client(ClientBase):
     id: UUID
     created_at: datetime

-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)


 class ApiKeyBase(BaseModel):
@@ -101,7 +101,7 @@ class AgentBase(BaseModel):
         description="Agent type (llm, sequential, parallel, loop, a2a, workflow, task)",
     )
     model: Optional[str] = Field(
-        None, description="Agent model (required only for llm type)"
+        None, description="LLM model identifier (required for LLM agents only)"
     )
     api_key_id: Optional[UUID4] = Field(
         None, description="Reference to a stored API Key ID"
@@ -115,8 +115,13 @@ class AgentBase(BaseModel):
     )
     config: Any = Field(None, description="Agent configuration based on type")

-    @validator("name")
-    def validate_name(cls, v, values):
+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v, info):
+        # Get values from validation context
+        values = info.data if hasattr(info, 'data') else {}
+
+        # A2A agents can have optional names
         if values.get("type") == "a2a":
             return v

@@ -127,107 +132,246 @@ class AgentBase(BaseModel):
             raise ValueError("Agent name cannot contain spaces or special characters")
         return v

-    @validator("type")
+    @field_validator("type")
+    @classmethod
     def validate_type(cls, v):
-        if v not in [
+        valid_types = [
             "llm",
             "sequential",
             "parallel",
             "loop",
             "a2a",
             "workflow",
-            "task",
-        ]:
+            "task"
+        ]
+        if v not in valid_types:
             raise ValueError(
-                "Invalid agent type. Must be: llm, sequential, parallel, loop, a2a, workflow or task"
+                f"Invalid agent type '{v}'. Must be one of: {', '.join(valid_types)}"
             )
         return v

-    @validator("agent_card_url")
-    def validate_agent_card_url(cls, v, values):
-        if "type" in values and values["type"] == "a2a":
+    @field_validator("agent_card_url")
+    @classmethod
+    def validate_agent_card_url(cls, v, info):
+        values = info.data if hasattr(info, 'data') else {}
+
+        if values.get("type") == "a2a":
             if not v:
                 raise ValueError("agent_card_url is required for a2a type agents")
             if not v.endswith("/.well-known/agent.json"):
                 raise ValueError("agent_card_url must end with /.well-known/agent.json")
         return v

-    @validator("model")
-    def validate_model(cls, v, values):
-        if "type" in values and values["type"] == "llm" and not v:
-            raise ValueError("Model is required for llm type agents")
+    @field_validator("model")
+    @classmethod
+    def validate_model(cls, v, info):
+        values = info.data if hasattr(info, 'data') else {}
+        agent_type = values.get("type")
+
+        if agent_type == "llm":
+            # For LLM agents, the model is required and cannot be empty
+            if not v or (isinstance(v, str) and v.strip() == ""):
+                raise ValueError(
+                    "LLM agents require a valid model configuration. "
+                    "Please specify a model identifier (e.g., 'gpt-4', 'claude-3-sonnet', 'gemini-pro')"
+                )
+
+            # Check that the model identifier has a valid format
+            if isinstance(v, str) and len(v.strip()) < 3:
+                raise ValueError("Model identifier must be at least 3 characters long")
+
+        elif agent_type in ["workflow", "task", "sequential", "parallel", "loop"]:
+            # These types should not have a model
+            if v and (isinstance(v, str) and v.strip()):
+                # Warn but allow (the model will be removed during creation)
+                import logging
+                logger = logging.getLogger(__name__)
+                logger.warning(f"{agent_type} agents don't need model configuration. Model will be ignored.")
+
         return v

-    @validator("api_key_id")
-    def validate_api_key_id(cls, v, values):
+    @field_validator("api_key_id")
+    @classmethod
+    def validate_api_key_id(cls, v, info):
+        values = info.data if hasattr(info, 'data') else {}
+        agent_type = values.get("type")
+
+        # An API key is required for LLM agents (unless it is in the config)
+        if agent_type == "llm" and not v:
+            # Check whether the config carries an API key
+            config = values.get("config", {})
+            if not config or not config.get("api_key"):
+                # Do not fail here; defer validation to creation time
+                pass
+
         return v

-    @validator("config")
-    def validate_config(cls, v, values):
-        if "type" in values and values["type"] == "a2a":
+    @field_validator("config")
+    @classmethod
+    def validate_config(cls, v, info):
+        values = info.data if hasattr(info, 'data') else {}
+        agent_type = values.get("type")
+
+        if not agent_type:
+            return v
+
+        # A2A agents have an optional config
+        if agent_type == "a2a":
             return v or {}

-        if "type" not in values:
+        # Workflow agents have a workflow-specific config
+        if agent_type == "workflow":
+            if v and isinstance(v, dict):
+                if not v.get("workflow"):
+                    raise ValueError("Workflow agents must have 'workflow' configuration")
             return v

-        # For workflow agents, we do not perform any validation
-        if "type" in values and values["type"] == "workflow":
-            return v
-
-        if not v and values.get("type") != "a2a":
+        # Config is required for the other types (except a2a)
+        if not v and agent_type not in ["a2a"]:
             raise ValueError(
-                f"Configuration is required for {values.get('type')} agent type"
+                f"Configuration is required for {agent_type} agent type"
             )

-        if values["type"] == "llm":
-            if isinstance(v, dict):
-                try:
-                    # Convert the dictionary to LLMConfig
-                    v = LLMConfig(**v)
-                except Exception as e:
-                    raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
-            elif not isinstance(v, LLMConfig):
-                raise ValueError("Invalid LLM configuration for agent")
-        elif values["type"] in ["sequential", "parallel", "loop"]:
-            if not isinstance(v, dict):
-                raise ValueError(f'Invalid configuration for agent {values["type"]}')
-            if "sub_agents" not in v:
-                raise ValueError(f'Agent {values["type"]} must have sub_agents')
-            if not isinstance(v["sub_agents"], list):
-                raise ValueError("sub_agents must be a list")
-            if not v["sub_agents"]:
-                raise ValueError(
-                    f'Agent {values["type"]} must have at least one sub-agent'
-                )
-        elif values["type"] == "task":
-            if not isinstance(v, dict):
-                raise ValueError(f'Invalid configuration for agent {values["type"]}')
-            if "tasks" not in v:
-                raise ValueError(f'Agent {values["type"]} must have tasks')
-            if not isinstance(v["tasks"], list):
-                raise ValueError("tasks must be a list")
-            if not v["tasks"]:
-                raise ValueError(f'Agent {values["type"]} must have at least one task')
-            for task in v["tasks"]:
-                if not isinstance(task, dict):
-                    raise ValueError("Each task must be a dictionary")
-                required_fields = ["agent_id", "description", "expected_output"]
-                for field in required_fields:
-                    if field not in task:
-                        raise ValueError(f"Task missing required field: {field}")
-
-            if "sub_agents" in v and v["sub_agents"] is not None:
-                if not isinstance(v["sub_agents"], list):
-                    raise ValueError("sub_agents must be a list")
-
-            return v
+        # Type-specific validation
+        if agent_type == "llm":
+            return cls._validate_llm_config(v)
+        elif agent_type in ["sequential", "parallel", "loop"]:
+            return cls._validate_composite_config(v, agent_type)
+        elif agent_type == "task":
+            return cls._validate_task_config(v)

         return v

+    @classmethod
+    def _validate_llm_config(cls, v):
+        """Validate configuration for LLM agents."""
+        if isinstance(v, dict):
+            try:
+                # Convert the dictionary to LLMConfig
+                v = LLMConfig(**v)
+            except Exception as e:
+                raise ValueError(f"Invalid LLM configuration: {str(e)}")
+        elif not isinstance(v, LLMConfig):
+            raise ValueError("Invalid LLM configuration format")
+        return v
+
+    @classmethod
+    def _validate_composite_config(cls, v, agent_type):
+        """Validate configuration for composite agents (sequential, parallel, loop)."""
+        if not isinstance(v, dict):
+            raise ValueError(f'Configuration for {agent_type} agent must be a dictionary')
+
+        if "sub_agents" not in v:
+            raise ValueError(f'{agent_type} agents must have sub_agents configuration')
+
+        if not isinstance(v["sub_agents"], list):
+            raise ValueError("sub_agents must be a list")
+
+        if not v["sub_agents"]:
+            raise ValueError(
+                f'{agent_type} agents must have at least one sub-agent'
+            )
+
+        # LoopAgent-specific validation
+        if agent_type == "loop":
+            max_iterations = v.get("max_iterations", 5)
+            if not isinstance(max_iterations, int) or max_iterations <= 0:
+                raise ValueError("max_iterations must be a positive integer")
+
+        return v
+
+    @classmethod
+    def _validate_task_config(cls, v):
+        """Validate configuration for task agents."""
+        if not isinstance(v, dict):
+            raise ValueError('Configuration for task agent must be a dictionary')
+
+        if "tasks" not in v:
+            raise ValueError('Task agents must have tasks configuration')
+
+        if not isinstance(v["tasks"], list):
+            raise ValueError("tasks must be a list")
+
+        if not v["tasks"]:
+            raise ValueError('Task agents must have at least one task')
+
+        # Validate each task individually
+        for i, task in enumerate(v["tasks"]):
+            if not isinstance(task, dict):
+                raise ValueError(f"Task {i+1} must be a dictionary")
+
+            required_fields = ["agent_id", "description", "expected_output"]
+            for field in required_fields:
+                if field not in task:
+                    raise ValueError(f"Task {i+1} missing required field: {field}")
+
+                # Check that the fields are not empty
+                if not task[field] or (isinstance(task[field], str) and not task[field].strip()):
+                    raise ValueError(f"Task {i+1} field '{field}' cannot be empty")
+
+        # Validate sub_agents if present
+        if "sub_agents" in v and v["sub_agents"] is not None:
+            if not isinstance(v["sub_agents"], list):
+                raise ValueError("sub_agents must be a list")
+
+        return v
+
+    @model_validator(mode='after')
+    def validate_agent_consistency(self):
+        """Cross-field validation for the agent."""
+
+        # Check consistency between type and configuration
+        if self.type == "llm":
+            # LLM agents must have a model
+            if not self.model or (isinstance(self.model, str) and self.model.strip() == ""):
+                raise ValueError("LLM agents must have a valid model")
+
+            # LLM agents must have an API key (in config or api_key_id)
+            has_api_key = bool(self.api_key_id)
+            if not has_api_key and self.config:
+                config_dict = self.config if isinstance(self.config, dict) else self.config.__dict__
+                has_api_key = bool(config_dict.get("api_key"))
+
+            if not has_api_key:
+                raise ValueError("LLM agents must have an API key configured")
+
+        elif self.type in ["workflow", "task", "sequential", "parallel", "loop"]:
+            # Orchestrator agents should not have a model
+            if self.model and isinstance(self.model, str) and self.model.strip():
+                import logging
+                logger = logging.getLogger(__name__)
+                logger.warning(f"{self.type} agents don't need model configuration. Clearing model.")
+                self.model = None
+
+        elif self.type == "a2a":
+            # A2A agents must have agent_card_url
+            if not self.agent_card_url:
+                raise ValueError("A2A agents must have agent_card_url")
+
+        return self
+

 class AgentCreate(AgentBase):
     client_id: UUID

+    @model_validator(mode='after')
+    def validate_creation_requirements(self):
+        """Creation-specific validations for agents."""
+
+        # Call the parent class validation
+        super().validate_agent_consistency()
+
+        # Creation-specific checks
+        if self.type == "llm":
+            # Be stricter about the model at creation time
+            if not self.model or len(self.model.strip()) < 3:
+                raise ValueError(
+                    "LLM agents require a valid model identifier (minimum 3 characters). "
+                    "Examples: 'gpt-4', 'claude-3-sonnet', 'gemini-pro'"
+                )
+
+        return self
+

 class Agent(AgentBase):
     id: UUID
@@ -237,17 +381,17 @@ class Agent(AgentBase):
     agent_card_url: Optional[str] = None
     folder_id: Optional[UUID4] = None

-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)

-    @validator("agent_card_url", pre=True)
-    def set_agent_card_url(cls, v, values):
+    @field_validator("agent_card_url", mode='before')
+    @classmethod
+    def set_agent_card_url(cls, v, info):
         if v:
             return v

+        values = info.data if hasattr(info, 'data') else {}
         if "id" in values:
             from os import getenv

             return f"{getenv('API_URL', '')}/api/v1/a2a/{values['id']}/.well-known/agent.json"

         return v
@@ -262,6 +406,7 @@ class ToolConfig(BaseModel):
     inputModes: List[str] = Field(default_factory=list)
     outputModes: List[str] = Field(default_factory=list)

+
 # Last edited by Arley Peter on 2025-05-17
 class MCPServerBase(BaseModel):
     name: str
@@ -272,6 +417,29 @@ class MCPServerBase(BaseModel):
     tools: Optional[List[ToolConfig]] = Field(default_factory=list)
     type: str = Field(default="official")

+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v):
+        if not v or not v.strip():
+            raise ValueError("MCP Server name cannot be empty")
+        return v.strip()
+
+    @field_validator("config_type")
+    @classmethod
+    def validate_config_type(cls, v):
+        valid_types = ["studio", "custom"]
+        if v not in valid_types:
+            raise ValueError(f"config_type must be one of: {valid_types}")
+        return v
+
+    @field_validator("type")
+    @classmethod
+    def validate_type(cls, v):
+        valid_types = ["official", "custom"]
+        if v not in valid_types:
+            raise ValueError(f"type must be one of: {valid_types}")
+        return v
+

 class MCPServerCreate(MCPServerBase):
     pass
@@ -282,8 +450,7 @@ class MCPServer(MCPServerBase):
     created_at: datetime
     updated_at: Optional[datetime] = None

-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)


 class ToolBase(BaseModel):
@@ -292,6 +459,13 @@ class ToolBase(BaseModel):
     config_json: Dict[str, Any] = Field(default_factory=dict)
     environments: Dict[str, Any] = Field(default_factory=dict)

+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v):
+        if not v or not v.strip():
+            raise ValueError("Tool name cannot be empty")
+        return v.strip()
+

 class ToolCreate(ToolBase):
     pass
@@ -302,14 +476,20 @@ class Tool(ToolBase):
     created_at: datetime
     updated_at: Optional[datetime] = None

-    class Config:
-        from_attributes = True
+    model_config = ConfigDict(from_attributes=True)


 class AgentFolderBase(BaseModel):
     name: str
     description: Optional[str] = None

+    @field_validator("name")
+    @classmethod
+    def validate_name(cls, v):
+        if not v or not v.strip():
+            raise ValueError("Folder name cannot be empty")
+        return v.strip()
+

 class AgentFolderCreate(AgentFolderBase):
     client_id: UUID4
@@ -326,3 +506,86 @@ class AgentFolder(AgentFolderBase):
     updated_at: Optional[datetime] = None

     model_config = ConfigDict(from_attributes=True)
+
+
+class AgentTypeInfo(BaseModel):
+    """Information about valid agent types."""
+    type: str
+    requires_model: bool
+    requires_config: bool
+    description: str
+
+    @classmethod
+    def get_valid_types(cls) -> Dict[str, 'AgentTypeInfo']:
+        """Return information about all valid agent types."""
+        return {
+            "llm": cls(
+                type="llm",
+                requires_model=True,
+                requires_config=True,
+                description="Large Language Model agent - requires model and API key"
+            ),
+            "workflow": cls(
+                type="workflow",
+                requires_model=False,
+                requires_config=True,
+                description="Workflow orchestrator agent - uses LangGraph for complex flows"
+            ),
+            "task": cls(
+                type="task",
+                requires_model=False,
+                requires_config=True,
+                description="Task management agent - coordinates multiple tasks"
+            ),
+            "sequential": cls(
+                type="sequential",
+                requires_model=False,
+                requires_config=True,
+                description="Sequential execution agent - runs sub-agents in order"
+            ),
+            "parallel": cls(
+                type="parallel",
+                requires_model=False,
+                requires_config=True,
+                description="Parallel execution agent - runs sub-agents concurrently"
+            ),
+            "loop": cls(
+                type="loop",
+                requires_model=False,
+                requires_config=True,
+                description="Loop execution agent - repeats sub-agents with conditions"
+            ),
+            "a2a": cls(
+                type="a2a",
+                requires_model=False,
+                requires_config=False,
+                description="Agent-to-Agent communication - external agent integration"
+            )
+        }
+
+
+class ModelValidationResult(BaseModel):
+    """Result of model validation."""
+    is_valid: bool
+    error_message: Optional[str] = None
+    warnings: List[str] = Field(default_factory=list)
+
+    @classmethod
+    def success(cls, warnings: List[str] = None) -> 'ModelValidationResult':
+        return cls(is_valid=True, warnings=warnings or [])
+
+    @classmethod
+    def failure(cls, error_message: str) -> 'ModelValidationResult':
+        return cls(is_valid=False, error_message=error_message)
+
+
+class AgentValidationSummary(BaseModel):
+    """Agent validation summary."""
+    agent_id: Optional[UUID] = None
+    agent_name: Optional[str] = None
+    agent_type: str
+    is_valid: bool
+    model_validation: ModelValidationResult
+    config_validation: ModelValidationResult
+    general_errors: List[str] = Field(default_factory=list)
+    warnings: List[str] = Field(default_factory=list)
@@ -67,15 +67,39 @@ class AgentBuilder:
         if agent_tools_ids and isinstance(agent_tools_ids, list):
             for agent_tool_id in agent_tools_ids:
                 sub_agent = get_agent(self.db, agent_tool_id)
-                llm_agent, _ = await self.build_llm_agent(sub_agent)
-                if llm_agent:
-                    agent_tools.append(AgentTool(agent=llm_agent))
+                if sub_agent:
+                    # Check that the sub_agent is of type LLM before creating an LlmAgent
+                    if sub_agent.type == "llm":
+                        llm_agent, _ = await self.build_llm_agent(sub_agent)
+                        if llm_agent:
+                            agent_tools.append(AgentTool(agent=llm_agent))
+                    else:
+                        logger.warning(f"Agent tool {agent_tool_id} is not of type 'llm', skipping")
+                else:
+                    logger.warning(f"Agent tool {agent_tool_id} not found")
         return agent_tools

+    def _validate_llm_agent_model(self, agent) -> None:
+        """Validate that LLM agent has a proper model configuration."""
+        if not hasattr(agent, 'model') or not agent.model:
+            logger.error(f"LLM agent {agent.name} does not have a model configured")
+            raise ValueError(f"LLM agent {agent.name} requires a model configuration")
+
+        if isinstance(agent.model, str) and agent.model.strip() == "":
+            logger.error(f"LLM agent {agent.name} has an empty model string")
+            raise ValueError(f"LLM agent {agent.name} has an empty model configuration")
+
+        logger.info(f"Model validation passed for agent {agent.name}: {agent.model}")
+
     async def _create_llm_agent(
         self, agent, enabled_tools: List[str] = []
     ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
         """Create an LLM agent from the agent data."""
+
+        self._validate_llm_agent_model(agent)
+
+        logger.info(f"Creating LLM agent: {agent.name} with model: {agent.model}")
+
         # Get custom tools from the configuration
         custom_tools = []
         custom_tools = self.custom_tool_builder.build_tools(agent.config)
@@ -110,7 +134,7 @@ class AgentBuilder:
             current_day_of_week=current_day_of_week,
             current_date_iso=current_date_iso,
             current_time=current_time,
-        )
+        ) if agent.instruction else ""

         # add role on beginning of the prompt
         if agent.role:
@@ -170,21 +194,27 @@ class AgentBuilder:
                 f"Agent {agent.name} does not have a configured API key"
             )

-        return (
-            LlmAgent(
+        if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
+            raise ValueError(f"Cannot create LiteLlm with empty model for agent {agent.name}")
+
+        try:
+            llm_agent = LlmAgent(
                 name=agent.name,
                 model=LiteLlm(model=agent.model, api_key=api_key),
                 instruction=formatted_prompt,
                 description=agent.description,
                 tools=all_tools,
-            ),
-            mcp_exit_stack,
-        )
+            )
+            logger.info(f"LLM agent created successfully: {agent.name}")
+            return llm_agent, mcp_exit_stack
+        except Exception as e:
+            logger.error(f"Error creating LLM agent {agent.name}: {str(e)}")
+            raise ValueError(f"Error creating LLM agent {agent.name}: {str(e)}")

     async def _get_sub_agents(
         self, sub_agent_ids: List[str]
-    ) -> List[Tuple[LlmAgent, Optional[AsyncExitStack]]]:
-        """Get and create LLM sub-agents."""
+    ) -> List[Tuple[BaseAgent, Optional[AsyncExitStack]]]:
+        """Get and create sub-agents with proper type validation."""
         sub_agents = []
         for sub_agent_id in sub_agent_ids:
             sub_agent_id_str = str(sub_agent_id)
@@ -197,39 +227,50 @@ class AgentBuilder:

             logger.info(f"Sub-agent found: {agent.name} (type: {agent.type})")

-            if agent.type == "llm":
-                sub_agent, exit_stack = await self._create_llm_agent(agent)
-            elif agent.type == "a2a":
-                sub_agent, exit_stack = await self.build_a2a_agent(agent)
-            elif agent.type == "workflow":
-                sub_agent, exit_stack = await self.build_workflow_agent(agent)
-            elif agent.type == "task":
-                sub_agent, exit_stack = await self.build_task_agent(agent)
-            elif agent.type == "sequential":
-                sub_agent, exit_stack = await self.build_composite_agent(agent)
-            elif agent.type == "parallel":
-                sub_agent, exit_stack = await self.build_composite_agent(agent)
-            elif agent.type == "loop":
-                sub_agent, exit_stack = await self.build_composite_agent(agent)
-            else:
-                raise ValueError(f"Invalid agent type: {agent.type}")
+            try:
+                if agent.type == "llm":
+                    # Check for a model before creating
+                    if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
+                        logger.error(f"LLM sub-agent {agent.name} does not have a model configured")
+                        raise ValueError(f"LLM sub-agent {agent.name} requires a model configuration")
+                    sub_agent, exit_stack = await self._create_llm_agent(agent)
+                elif agent.type == "a2a":
+                    sub_agent, exit_stack = await self.build_a2a_agent(agent)
+                elif agent.type == "workflow":
+                    # Workflow agents do not need a model
+                    sub_agent, exit_stack = await self.build_workflow_agent(agent)
+                elif agent.type == "task":
+                    sub_agent, exit_stack = await self.build_task_agent(agent)
+                elif agent.type == "sequential":
+                    sub_agent, exit_stack = await self.build_composite_agent(agent)
+                elif agent.type == "parallel":
+                    sub_agent, exit_stack = await self.build_composite_agent(agent)
+                elif agent.type == "loop":
+                    sub_agent, exit_stack = await self.build_composite_agent(agent)
+                else:
+                    raise ValueError(f"Invalid agent type: {agent.type}")

-            sub_agents.append((sub_agent, exit_stack))
-            logger.info(f"Sub-agent added: {agent.name}")
+                sub_agents.append((sub_agent, exit_stack))
+                logger.info(f"Sub-agent added: {agent.name}")
+
+            except Exception as e:
+                logger.error(f"Error creating sub-agent {agent.name}: {str(e)}")
+                raise ValueError(f"Error creating sub-agent {agent.name}: {str(e)}")

         logger.info(f"Sub-agents created: {len(sub_agents)}")
-        logger.info(f"Sub-agents: {str(sub_agents)}")

         return sub_agents

     async def build_llm_agent(
         self, root_agent, enabled_tools: List[str] = []
     ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
         """Build an LLM agent with its sub-agents."""
-        logger.info("Creating LLM agent")
+        logger.info(f"Creating LLM agent: {root_agent.name}")
+
+        if root_agent.type != "llm":
+            raise ValueError(f"Expected LLM agent, got {root_agent.type}")
+
         sub_agents = []
-        if root_agent.config.get("sub_agents"):
+        if root_agent.config and root_agent.config.get("sub_agents"):
             sub_agents_with_stacks = await self._get_sub_agents(
                 root_agent.config.get("sub_agents")
             )
@@ -241,20 +282,21 @@ class AgentBuilder:
         if sub_agents:
             root_llm_agent.sub_agents = sub_agents

+        logger.info(f"LLM agent built successfully: {root_agent.name}")
         return root_llm_agent, exit_stack

     async def build_a2a_agent(
         self, root_agent
-    ) -> Tuple[BaseAgent, Optional[AsyncExitStack]]:
+    ) -> Tuple[A2ACustomAgent, Optional[AsyncExitStack]]:
         """Build an A2A agent with its sub-agents."""
-        logger.info(f"Creating A2A agent from {root_agent.agent_card_url}")
+        logger.info(f"Creating A2A agent from {root_agent.name}")

         if not root_agent.agent_card_url:
             raise ValueError("agent_card_url is required for a2a agents")

         try:
             sub_agents = []
-            if root_agent.config.get("sub_agents"):
+            if root_agent.config and root_agent.config.get("sub_agents"):
                 sub_agents_with_stacks = await self._get_sub_agents(
                     root_agent.config.get("sub_agents")
                 )
@@ -288,6 +330,9 @@ class AgentBuilder:
         """Build a workflow agent with its sub-agents."""
         logger.info(f"Creating Workflow agent from {root_agent.name}")

+        if root_agent.type != "workflow":
+            raise ValueError(f"Expected workflow agent, got {root_agent.type}")
+
         agent_config = root_agent.config or {}

         if not agent_config.get("workflow"):
@@ -295,7 +340,7 @@

         try:
             sub_agents = []
-            if root_agent.config.get("sub_agents"):
+            if root_agent.config and root_agent.config.get("sub_agents"):
                 sub_agents_with_stacks = await self._get_sub_agents(
                     root_agent.config.get("sub_agents")
                 )
@@ -304,15 +349,20 @@ class AgentBuilder:
             config = root_agent.config or {}
             timeout = config.get("timeout", 300)

-            workflow_agent = WorkflowAgent(
-                name=root_agent.name,
-                flow_json=agent_config.get("workflow"),
-                timeout=timeout,
-                description=root_agent.description
-                or f"Workflow Agent for {root_agent.name}",
-                sub_agents=sub_agents,
-                db=self.db,
-            )
+            kwargs = {
+                "name": root_agent.name,
+                "flow_json": agent_config.get("workflow"),
+                "timeout": timeout,
+                "description": root_agent.description or f"Workflow Agent for {root_agent.name}",
+                "sub_agents": sub_agents,
+                "db": self.db,
+            }
+
+            # If the root_agent has a model, do not pass it to the WorkflowAgent
+            if hasattr(root_agent, 'model') and root_agent.model:
+                logger.warning(f"Workflow agent {root_agent.name} has model '{root_agent.model}' configured, but workflow agents should not have models. Ignoring model.")
+
+            workflow_agent = WorkflowAgent(**kwargs)

             logger.info(f"Workflow agent created successfully: {root_agent.name}")

@@ -328,6 +378,9 @@
         """Build a task agent with its sub-agents."""
         logger.info(f"Creating Task agent: {root_agent.name}")

+        if root_agent.type != "task":
+            raise ValueError(f"Expected task agent, got {root_agent.type}")
+
         agent_config = root_agent.config or {}

         if not agent_config.get("tasks"):
@@ -336,7 +389,7 @@
         try:
             # Get sub-agents if there are any
             sub_agents = []
-            if root_agent.config.get("sub_agents"):
+            if root_agent.config and root_agent.config.get("sub_agents"):
                 sub_agents_with_stacks = await self._get_sub_agents(
                     root_agent.config.get("sub_agents")
                 )
@@ -380,7 +433,11 @@
             f"Processing sub-agents for agent {root_agent.type} (ID: {root_agent.id}, Name: {root_agent.name})"
         )

-        if not root_agent.config.get("sub_agents"):
+        valid_composite_types = ["sequential", "parallel", "loop"]
+        if root_agent.type not in valid_composite_types:
+            raise ValueError(f"Expected composite agent type ({valid_composite_types}), got {root_agent.type}")
+
+        if not root_agent.config or not root_agent.config.get("sub_agents"):
             logger.error(
                 f"Sub_agents configuration not found or empty for agent {root_agent.name}"
             )
@@ -401,39 +458,51 @@
         sub_agents = [agent for agent, _ in sub_agents_with_stacks]
         logger.info(f"Extracted sub-agents: {[agent.name for agent in sub_agents]}")

-        if root_agent.type == "sequential":
-            logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents")
-            return (
-                SequentialAgent(
-                    name=root_agent.name,
-                    sub_agents=sub_agents,
-                    description=root_agent.config.get("description", ""),
-                ),
-                None,
-            )
-        elif root_agent.type == "parallel":
-            logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents")
-            return (
-                ParallelAgent(
-                    name=root_agent.name,
-                    sub_agents=sub_agents,
-                    description=root_agent.config.get("description", ""),
-                ),
-                None,
-            )
-        elif root_agent.type == "loop":
-            logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents")
-            return (
-                LoopAgent(
-                    name=root_agent.name,
-                    sub_agents=sub_agents,
-                    description=root_agent.config.get("description", ""),
-                    max_iterations=root_agent.config.get("max_iterations", 5),
-                ),
-                None,
-            )
-        else:
-            raise ValueError(f"Invalid agent type: {root_agent.type}")
+        if not sub_agents:
+            raise ValueError(f"No valid sub-agents found for {root_agent.type} agent {root_agent.name}")
+
+        try:
+            if root_agent.type == "sequential":
+                logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents")
+                return (
+                    SequentialAgent(
+                        name=root_agent.name,
+                        sub_agents=sub_agents,
+                        description=root_agent.description or root_agent.config.get("description", ""),
+                    ),
+                    None,
+                )
+            elif root_agent.type == "parallel":
+                logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents")
+                return (
+                    ParallelAgent(
+                        name=root_agent.name,
+                        sub_agents=sub_agents,
+                        description=root_agent.description or root_agent.config.get("description", ""),
+                    ),
+                    None,
+                )
+            elif root_agent.type == "loop":
+                logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents")
+                max_iterations = root_agent.config.get("max_iterations", 5)
+                if max_iterations <= 0:
+                    logger.warning(f"Invalid max_iterations ({max_iterations}) for LoopAgent, using default 5")
+                    max_iterations = 5
+                return (
+                    LoopAgent(
+                        name=root_agent.name,
+                        sub_agents=sub_agents,
+                        description=root_agent.description or root_agent.config.get("description", ""),
+                        max_iterations=max_iterations,
+                    ),
+                    None,
+                )
+            else:
+                raise ValueError(f"Invalid composite agent type: {root_agent.type}")
+
+        except Exception as e:
+            logger.error(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")
+            raise ValueError(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")

     async def build_agent(self, root_agent, enabled_tools: List[str] = []) -> Tuple[
         LlmAgent
@@ -446,13 +515,29 @@
         Optional[AsyncExitStack],
     ]:
         """Build the appropriate agent based on the type of the root agent."""

-        if root_agent.type == "llm":
-            return await self.build_llm_agent(root_agent, enabled_tools)
-        elif root_agent.type == "a2a":
-            return await self.build_a2a_agent(root_agent)
-        elif root_agent.type == "workflow":
-            return await self.build_workflow_agent(root_agent)
-        elif root_agent.type == "task":
-            return await self.build_task_agent(root_agent)
-        else:
-            return await self.build_composite_agent(root_agent)
+        if not root_agent:
+            raise ValueError("root_agent cannot be None")
+
+        if not hasattr(root_agent, 'type') or not root_agent.type:
+            raise ValueError("root_agent must have a valid type")
+
+        logger.info(f"Building agent: {root_agent.name} (type: {root_agent.type})")
+
+        try:
+            if root_agent.type == "llm":
+                return await self.build_llm_agent(root_agent, enabled_tools)
+            elif root_agent.type == "a2a":
+                return await self.build_a2a_agent(root_agent)
+            elif root_agent.type == "workflow":
+                return await self.build_workflow_agent(root_agent)
+            elif root_agent.type == "task":
+                return await self.build_task_agent(root_agent)
+            elif root_agent.type in ["sequential", "parallel", "loop"]:
+                return await self.build_composite_agent(root_agent)
+            else:
+                raise ValueError(f"Unknown agent type: {root_agent.type}")
+
+        except Exception as e:
+            logger.error(f"Error building agent {root_agent.name}: {str(e)}")
+            raise
@@ -458,7 +458,7 @@ async def run_agent_stream(

     async for event in events_async:
         try:
-            event_dict = event.dict()
+            event_dict = event.model_dump()
            event_dict = convert_sets(event_dict)

            if "content" in event_dict and event_dict["content"]:
|
@ -30,6 +30,7 @@
|
|||||||
└──────────────────────────────────────────────────────────────────────────────┘
|
└──────────────────────────────────────────────────────────────────────────────┘
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from google.adk.agents import BaseAgent
|
from google.adk.agents import BaseAgent
|
||||||
from google.adk.agents.invocation_context import InvocationContext
|
from google.adk.agents.invocation_context import InvocationContext
|
||||||
@@ -40,11 +41,14 @@ from typing import AsyncGenerator, Dict, Any, List, TypedDict
 import uuid

 from src.services.agent_service import get_agent
+from src.utils.logger import setup_logger

 from sqlalchemy.orm import Session

 from langgraph.graph import StateGraph, END

+logger = setup_logger(__name__)
+

 class State(TypedDict):
     content: List[Event]
@@ -63,6 +67,9 @@ class WorkflowAgent(BaseAgent):

     This agent allows defining and executing complex workflows between multiple agents
     using LangGraph for orchestration.
+
+    IMPORTANT: Workflow agents are orchestrators and should NOT have a model configured.
+    They delegate to sub-agents that have their own models.
     """

     # Field declarations for Pydantic
@@ -89,6 +96,21 @@ class WorkflowAgent(BaseAgent):
             sub_agents: List of sub-agents to be executed after the workflow agent
             db: Session
         """
+
+        # Workflow agents must not have models
+        if 'model' in kwargs:
+            logger.warning(f"Removing model from workflow agent {name}. Workflow agents should not have models.")
+            del kwargs['model']
+
+        if not flow_json:
+            raise ValueError(f"Workflow agent {name} requires flow_json configuration")
+
+        if not isinstance(flow_json, dict):
+            raise ValueError(f"Workflow agent {name} flow_json must be a dictionary")
+
+        if not flow_json.get('nodes'):
+            raise ValueError(f"Workflow agent {name} flow_json must contain nodes")
+
         # Initialize base class
         super().__init__(
             name=name,
@@ -98,9 +120,13 @@ class WorkflowAgent(BaseAgent):
             db=db,
             **kwargs,
         )

+        if hasattr(self, 'model'):
+            logger.warning(f"Workflow agent {name} had a model attribute. Removing it.")
+            delattr(self, 'model')
+
-        print(
-            f"Workflow agent initialized with {len(flow_json.get('nodes', []))} nodes"
+        logger.info(
+            f"Workflow agent '{name}' initialized with {len(flow_json.get('nodes', []))} nodes"
         )
@@ -112,11 +138,12 @@ class WorkflowAgent(BaseAgent):
             node_id: str,
             node_data: Dict[str, Any],
         ) -> AsyncGenerator[State, None]:
-            print("\n🏁 INITIAL NODE")
+            logger.info(f"🏁 INITIAL NODE: {node_id}")

             content = state.get("content", [])

             if not content:
+                logger.warning("No content found in initial state")
                 content = [
                     Event(
                         author=f"workflow-node:{node_id}",
@@ -128,9 +155,11 @@ class WorkflowAgent(BaseAgent):
                     "status": "error",
                     "node_outputs": {},
                     "cycle_count": 0,
-                    "conversation_history": ctx.session.events,
+                    "conversation_history": ctx.session.events if ctx.session else [],
+                    "session_id": state.get("session_id", ""),
                 }
                 return

             session_id = state.get("session_id", "")

             # Store specific results for this node
@@ -149,7 +178,7 @@ class WorkflowAgent(BaseAgent):
                 "node_outputs": node_outputs,
                 "cycle_count": 0,
                 "session_id": session_id,
-                "conversation_history": ctx.session.events,
+                "conversation_history": ctx.session.events if ctx.session else [],
             }

         # Generic function for agent nodes
@@ -163,7 +192,7 @@ class WorkflowAgent(BaseAgent):

             # Increment cycle counter
             cycle_count = state.get("cycle_count", 0) + 1
-            print(f"\n👤 AGENT: {agent_name} (Cycle {cycle_count})")
+            logger.info(f"👤 AGENT: {agent_name} (Cycle {cycle_count})")

             content = state.get("content", [])
             session_id = state.get("session_id", "")
@@ -171,14 +200,13 @@ class WorkflowAgent(BaseAgent):
             # Get conversation history
             conversation_history = state.get("conversation_history", [])

-            agent = get_agent(self.db, agent_id)
-
-            if not agent:
+            if not agent_id:
+                logger.error(f"Agent node {node_id} does not have a valid agent_id")
                 yield {
                     "content": [
                         Event(
                             author=f"workflow-node:{node_id}",
-                            content=Content(parts=[Part(text="Agent not found")]),
+                            content=Content(parts=[Part(text="Agent ID not configured")]),
                         )
                     ],
                     "session_id": session_id,
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
# Import moved to inside the function to avoid circular import
|
agent = get_agent(self.db, agent_id)
|
||||||
from src.services.adk.agent_builder import AgentBuilder
|
|
||||||
|
|
||||||
agent_builder = AgentBuilder(self.db)
|
if not agent:
|
||||||
root_agent, exit_stack = await agent_builder.build_agent(agent)
|
logger.error(f"Agent not found for ID: {agent_id}")
|
||||||
|
yield {
|
||||||
|
"content": [
|
||||||
|
Event(
|
||||||
|
author=f"workflow-node:{node_id}",
|
||||||
|
content=Content(parts=[Part(text=f"Agent not found: {agent_id}")]),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"session_id": session_id,
|
||||||
|
"status": "error",
|
||||||
|
"node_outputs": {},
|
||||||
|
"cycle_count": cycle_count,
|
||||||
|
"conversation_history": conversation_history,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
new_content = []
|
try:
|
||||||
async for event in root_agent.run_async(ctx):
|
# Import moved to inside the function to avoid circular import
|
||||||
conversation_history.append(event)
|
from src.services.adk.agent_builder import AgentBuilder
|
||||||
|
|
||||||
modified_event = Event(
|
|
||||||
author=f"workflow-node:{node_id}", content=event.content
|
|
||||||
)
|
|
||||||
new_content.append(modified_event)
|
|
||||||
|
|
||||||
|
agent_builder = AgentBuilder(self.db)
|
||||||
|
root_agent, exit_stack = await agent_builder.build_agent(agent)
|
||||||
|
|
||||||
print(f"New content: {new_content}")
|
new_content = []
|
||||||
|
async for event in root_agent.run_async(ctx):
|
||||||
|
conversation_history.append(event)
|
||||||
|
|
||||||
|
modified_event = Event(
|
||||||
|
author=f"workflow-node:{node_id}", content=event.content
|
||||||
|
)
|
||||||
|
new_content.append(modified_event)
|
||||||
|
|
||||||
node_outputs = state.get("node_outputs", {})
|
logger.debug(f"Agent {agent_name} generated {len(new_content)} events")
|
||||||
node_outputs[node_id] = {
|
|
||||||
"processed_by": agent_name,
|
|
||||||
"agent_content": new_content,
|
|
||||||
"cycle": cycle_count,
|
|
||||||
}
|
|
||||||
|
|
||||||
content = content + new_content
|
node_outputs = state.get("node_outputs", {})
|
||||||
|
node_outputs[node_id] = {
|
||||||
|
"processed_by": agent_name,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"agent_content": new_content,
|
||||||
|
"cycle": cycle_count,
|
||||||
|
"processed_at": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
yield {
|
content = content + new_content
|
||||||
"content": content,
|
|
||||||
"status": "processed_by_agent",
|
|
||||||
"node_outputs": node_outputs,
|
|
||||||
"cycle_count": cycle_count,
|
|
||||||
"conversation_history": conversation_history,
|
|
||||||
"session_id": session_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
if exit_stack:
|
yield {
|
||||||
await exit_stack.aclose()
|
"content": content,
|
||||||
|
"status": "processed_by_agent",
|
||||||
|
"node_outputs": node_outputs,
|
||||||
|
"cycle_count": cycle_count,
|
||||||
|
"conversation_history": conversation_history,
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if exit_stack:
|
||||||
|
try:
|
||||||
|
await exit_stack.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing exit stack for agent {agent_name}: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing agent {agent_name}: {str(e)}")
|
||||||
|
yield {
|
||||||
|
"content": [
|
||||||
|
Event(
|
||||||
|
author=f"workflow-node:{node_id}",
|
||||||
|
content=Content(parts=[Part(text=f"Error executing agent: {str(e)}")]),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"session_id": session_id,
|
||||||
|
"status": "agent_error",
|
||||||
|
"node_outputs": state.get("node_outputs", {}),
|
||||||
|
"cycle_count": cycle_count,
|
||||||
|
"conversation_history": conversation_history,
|
||||||
|
}
|
||||||
|
|
||||||
# Function for condition nodes
|
# Function for condition nodes
|
||||||
async def condition_node_function(
|
async def condition_node_function(
|
||||||
@@ -236,7 +304,7 @@ class WorkflowAgent(BaseAgent):
            conditions = node_data.get("conditions", [])
            cycle_count = state.get("cycle_count", 0)

-            print(f"\n🔄 CONDITION: {label} (Cycle {cycle_count})")
+            logger.info(f"🔄 CONDITION: {label} (Cycle {cycle_count})")

            content = state.get("content", [])
            conversation_history = state.get("conversation_history", [])
@@ -245,16 +313,17 @@ class WorkflowAgent(BaseAgent):
if content and len(content) > 0:
if content and len(content) > 0:
for event in reversed(content):
for event in reversed(content):
if (
if (
event.author != "agent"
or not hasattr(event.content, "parts")
or not event.content.parts
hasattr(event, 'author') and
event.author != "user" and
hasattr(event, 'content') and
hasattr(event.content, "parts") and
event.content.parts
):
):
latest_event = event
latest_event = event
break
break

if latest_event:
if latest_event:
print(
f"Evaluating condition only for the most recent event: '{latest_event}'"
)
logger.debug(f"Evaluating condition for latest event from: {latest_event.author}")

# Use only the most recent event for condition evaluation
# Use only the most recent event for condition evaluation
evaluation_state = state.copy()
evaluation_state = state.copy()
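The rewritten guard above replaces the old inverted author check with a positive hasattr chain, so only non-user events with actual content parts are evaluated. Roughly equivalent logic as a standalone helper (the helper names are hypothetical, for illustration only):

def is_evaluable_event(event) -> bool:
    # Mirrors the hasattr chain in the diff: a non-user author plus
    # non-empty content parts makes an event usable for evaluation.
    return (
        hasattr(event, "author")
        and event.author != "user"
        and hasattr(event, "content")
        and event.content is not None
        and hasattr(event.content, "parts")
        and bool(event.content.parts)
    )

def latest_evaluable_event(content: list):
    # Walk newest-first, same as the reversed() loop above.
    for event in reversed(content):
        if is_evaluable_event(event):
            return event
    return None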
@@ -273,25 +342,24 @@ class WorkflowAgent(BaseAgent):
operator = condition_data.get("operator")
operator = condition_data.get("operator")
expected_value = condition_data.get("value")
expected_value = condition_data.get("value")

print(
f" Checking if {field} {operator} '{expected_value}' (current value: '{evaluation_state.get(field, '')}')"
)
logger.debug(
f"Checking condition: {field} {operator} '{expected_value}'"
)

if self._evaluate_condition(condition, evaluation_state):
if self._evaluate_condition(condition, evaluation_state):
conditions_met.append(condition_id)
conditions_met.append(condition_id)
condition_details.append(
condition_details.append(
f"{field} {operator} '{expected_value}' ✅"
f"{field} {operator} '{expected_value}' ✅"
)
)
print(f" ✅ Condition {condition_id} met!")
logger.info(f"✅ Condition {condition_id} met!")
else:
else:
condition_details.append(
condition_details.append(
f"{field} {operator} '{expected_value}' ❌"
f"{field} {operator} '{expected_value}' ❌"
)
)

# Check if the cycle reached the limit (extra security)
if cycle_count >= 10:
print(
f"⚠️ ATTENTION: Cycle limit reached ({cycle_count}). Forcing termination."
)
max_cycles = 10  # Could come from configuration
if cycle_count >= max_cycles:
logger.warning(f"Cycle limit reached ({cycle_count}). Forcing termination.")

condition_content = [
condition_content = [
Event(
Event(
@@ -314,10 +382,10 @@ class WorkflowAgent(BaseAgent):
node_outputs = state.get("node_outputs", {})
node_outputs = state.get("node_outputs", {})
node_outputs[node_id] = {
node_outputs[node_id] = {
"condition_evaluated": label,
"condition_evaluated": label,
"content_evaluated": content,
"conditions_met": conditions_met,
"conditions_met": conditions_met,
"condition_details": condition_details,
"condition_details": condition_details,
"cycle": cycle_count,
"cycle": cycle_count,
"evaluated_at": datetime.now().isoformat(),
}
}

# Prepare a more descriptive message about the conditions
# Prepare a more descriptive message about the conditions
@@ -334,7 +402,8 @@ class WorkflowAgent(BaseAgent):
)
)
]
]
),
),
) ]
)
]
content = content + condition_content
content = content + condition_content

yield {
yield {
@@ -353,7 +422,7 @@ class WorkflowAgent(BaseAgent):
message_type = message_data.get("type", "text")
message_type = message_data.get("type", "text")
message_content = message_data.get("content", "")
message_content = message_data.get("content", "")

print(f"\n💬 MESSAGE-NODE: {message_content}")
logger.info(f"💬 MESSAGE-NODE: {message_content}")

content = state.get("content", [])
content = state.get("content", [])
session_id = state.get("session_id", "")
session_id = state.get("session_id", "")
@@ -371,6 +440,8 @@ class WorkflowAgent(BaseAgent):
node_outputs[node_id] = {
node_outputs[node_id] = {
"message_type": message_type,
"message_type": message_type,
"message_content": message_content,
"message_content": message_content,
"label": label,
"processed_at": datetime.now().isoformat(),
}
}

yield {
yield {
@@ -378,7 +449,8 @@ class WorkflowAgent(BaseAgent):
"status": "message_added",
"status": "message_added",
"node_outputs": node_outputs,
"node_outputs": node_outputs,
"cycle_count": state.get("cycle_count", 0),
"cycle_count": state.get("cycle_count", 0),
"conversation_history": conversation_history, "session_id": session_id,
"conversation_history": conversation_history,
"session_id": session_id,
}
}

async def delay_node_function(
async def delay_node_function(
@@ -389,6 +461,10 @@ class WorkflowAgent(BaseAgent):
delay_unit = delay_data.get("unit", "seconds")
delay_unit = delay_data.get("unit", "seconds")
delay_description = delay_data.get("description", "")
delay_description = delay_data.get("description", "")

if delay_value <= 0:
logger.warning(f"Invalid delay value: {delay_value}. Using 1 second.")
delay_value = 1

# Convert to seconds based on unit
# Convert to seconds based on unit
delay_seconds = delay_value
delay_seconds = delay_value
if delay_unit == "minutes":
if delay_unit == "minutes":
@@ -397,7 +473,7 @@ class WorkflowAgent(BaseAgent):
delay_seconds = delay_value * 3600
delay_seconds = delay_value * 3600

label = node_data.get("label", "delay_node")
label = node_data.get("label", "delay_node")
print(f"\n⏱️ DELAY-NODE: {delay_value} {delay_unit} - {delay_description}")
logger.info(f"⏱️ DELAY-NODE: {delay_value} {delay_unit} ({delay_seconds}s) - {delay_description}")

content = state.get("content", [])
content = state.get("content", [])
session_id = state.get("session_id", "")
session_id = state.get("session_id", "")
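The delay handling across these two hunks first clamps invalid values, then converts units to seconds. A compact sketch of that combined logic, with the factors taken from the diff and the helper name invented here for illustration:

import logging

logger = logging.getLogger(__name__)

UNIT_FACTORS = {"seconds": 1, "minutes": 60, "hours": 3600}

def delay_to_seconds(delay_value: float, delay_unit: str) -> float:
    # Clamp invalid values to 1 second, as the added validation does.
    if delay_value <= 0:
        logger.warning(f"Invalid delay value: {delay_value}. Using 1 second.")
        delay_value = 1
    # Unknown units fall back to seconds, matching the if/elif chain.
    return delay_value * UNIT_FACTORS.get(delay_unit, 1)

assert delay_to_seconds(2, "minutes") == 120
assert delay_to_seconds(-5, "seconds") == 1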
@@ -409,13 +485,17 @@ class WorkflowAgent(BaseAgent):
"delay_value": delay_value,
"delay_value": delay_value,
"delay_unit": delay_unit,
"delay_unit": delay_unit,
"delay_seconds": delay_seconds,
"delay_seconds": delay_seconds,
"delay_description": delay_description,
"delay_start_time": datetime.now().isoformat(),
"delay_start_time": datetime.now().isoformat(),
}
}

# Actually perform the delay
# Actually perform the delay
import asyncio
import asyncio
await asyncio.sleep(delay_seconds)
try:
await asyncio.sleep(delay_seconds)
except asyncio.CancelledError:
logger.warning(f"Delay in node {node_id} was cancelled")
# Continue execution even if delay was cancelled

# Update node outputs with completion information
# Update node outputs with completion information
node_outputs[node_id]["delay_end_time"] = datetime.now().isoformat()
node_outputs[node_id]["delay_end_time"] = datetime.now().isoformat()
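Catching asyncio.CancelledError and continuing, as above, is a deliberate trade-off: the node still records its end time even when the surrounding task is cancelled. A standalone sketch of that pattern; note that production code often re-raises CancelledError so callers can observe the cancellation:

import asyncio
import logging

logger = logging.getLogger(__name__)

async def interruptible_delay(node_id: str, delay_seconds: float) -> None:
    try:
        await asyncio.sleep(delay_seconds)
    except asyncio.CancelledError:
        # The diff logs and continues; re-raising is the usual alternative.
        logger.warning(f"Delay in node {node_id} was cancelled")

asyncio.run(interruptible_delay("delay-1", 0.01))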
@@ -424,7 +504,8 @@ class WorkflowAgent(BaseAgent):
yield {
yield {
"content": content,
"content": content,
"status": "delay_completed",
"status": "delay_completed",
"node_outputs": node_outputs, "cycle_count": state.get("cycle_count", 0),
"node_outputs": node_outputs,
"cycle_count": state.get("cycle_count", 0),
"conversation_history": conversation_history,
"conversation_history": conversation_history,
"session_id": session_id,
"session_id": session_id,
}
}
@@ -452,7 +533,7 @@ class WorkflowAgent(BaseAgent):

result = self._process_condition(operator, actual_value, expected_value)
result = self._process_condition(operator, actual_value, expected_value)

print(f" Check '{operator}': {result}")
logger.debug(f"Condition check '{operator}': {result}")
return result
return result

return False
return False
@@ -488,7 +569,7 @@ class WorkflowAgent(BaseAgent):

if extracted_texts:
if extracted_texts:
joined_text = " ".join(extracted_texts)
joined_text = " ".join(extracted_texts)
print(f" Extracted text from events: '{joined_text[:100]}...'")
logger.debug(f"Extracted text from events: '{joined_text[:100]}...'")
return joined_text
return joined_text

return ""
return ""
@@ -524,6 +605,7 @@ class WorkflowAgent(BaseAgent):
elif operator in ["matches", "not_matches"]:
elif operator in ["matches", "not_matches"]:
return self._check_regex(operator, actual_str, expected_str)
return self._check_regex(operator, actual_str, expected_str)

logger.warning(f"Unknown operator: {operator}")
return False
return False

def _check_definition(self, operator, actual_value):
def _check_definition(self, operator, actual_value):
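The new warning makes the operator dispatch fail loudly instead of silently returning False on a typo. The same dispatch can be written table-driven; this sketch covers only a subset of the operators seen in the hunks, and the dispatch-table shape is an assumption, not the project's implementation:

import logging

logger = logging.getLogger(__name__)

def process_condition(operator: str, actual_str: str, expected_str: str) -> bool:
    # Table-driven equivalent of the if/elif chain in the diff.
    handlers = {
        "contains": lambda a, e: e.lower() in a.lower(),
        "not_contains": lambda a, e: e.lower() not in a.lower(),
        "equals": lambda a, e: a == e,
        "not_equals": lambda a, e: a != e,
    }
    handler = handlers.get(operator)
    if handler is None:
        logger.warning(f"Unknown operator: {operator}")
        return False
    return handler(actual_str, expected_str)

assert process_condition("contains", "Hello World", "world") is True
assert process_condition("bogus", "a", "b") is False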
@@ -563,8 +645,8 @@ class WorkflowAgent(BaseAgent):
else: # less_than_or_equal
else: # less_than_or_equal
return actual_num <= expected_num
return actual_num <= expected_num
except (ValueError, TypeError):
except (ValueError, TypeError):
print(
logger.warning(
f" Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
f"Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
)
)
return False
return False
@@ -579,7 +661,7 @@ class WorkflowAgent(BaseAgent):
else: # not_matches
else: # not_matches
return not bool(pattern.search(actual_str))
return not bool(pattern.search(actual_str))
except re.error:
except re.error:
print(f" Error in regular expression: '{expected_str}'")
logger.warning(f"Error in regular expression: '{expected_str}'")
return (
return (
operator == "not_matches"
operator == "not_matches"
) # Return True for not_matches, False for matches
) # Return True for not_matches, False for matches
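The except branch above treats a bad pattern as "no match", so `matches` fails closed and `not_matches` fails open. A self-contained version of that exact behavior (function name invented here):

import re

def check_regex(operator: str, actual_str: str, expected_str: str) -> bool:
    try:
        pattern = re.compile(expected_str)
        if operator == "matches":
            return bool(pattern.search(actual_str))
        return not bool(pattern.search(actual_str))  # not_matches
    except re.error:
        # Invalid pattern: fail closed for matches, open for not_matches,
        # exactly as the except branch in the diff does.
        return operator == "not_matches"

assert check_regex("matches", "abc123", r"\d+") is True
assert check_regex("matches", "abc", r"[unclosed") is False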
@@ -589,8 +671,8 @@ class WorkflowAgent(BaseAgent):
expected_lower = expected_str.lower()
expected_lower = expected_str.lower()
actual_lower = actual_str.lower()
actual_lower = actual_str.lower()

print(
logger.debug(
f" Comparison '{operator}' without case distinction: '{expected_lower}' in '{actual_lower[:100]}...'"
f"Comparison '{operator}' case insensitive: '{expected_lower}' in '{actual_lower[:100]}...'"
)
)

if operator == "contains":
if operator == "contains":
@@ -627,14 +709,13 @@ class WorkflowAgent(BaseAgent):
# Routing function for each specific node
# Routing function for each specific node
def create_router_for_node(node_id: str):
def create_router_for_node(node_id: str):
def router(state: State) -> str:
def router(state: State) -> str:
print(f"Routing from node: {node_id}")
logger.debug(f"Routing from node: {node_id}")

# Check if the cycle limit has been reached
# Check if the cycle limit has been reached
cycle_count = state.get("cycle_count", 0)
cycle_count = state.get("cycle_count", 0)
if cycle_count >= 10:
print(
f"⚠️ Cycle limit ({cycle_count}) reached. Finalizing the flow."
)
max_cycles = 10  # Configurable
if cycle_count >= max_cycles:
logger.warning(f"Cycle limit ({cycle_count}) reached. Finalizing the flow.")
return END
return END

# If it's a condition node, evaluate the conditions
# If it's a condition node, evaluate the conditions
@@ -648,32 +729,29 @@ class WorkflowAgent(BaseAgent):
if conditions_met:
if conditions_met:
any_condition_met = True
any_condition_met = True
condition_id = conditions_met[0]
condition_id = conditions_met[0]
print(
f"Using stored condition evaluation result: Condition {condition_id} met."
)
logger.debug(f"Using stored condition result: Condition {condition_id} met.")
if (
if (
node_id in edges_map
node_id in edges_map
and condition_id in edges_map[node_id]
and condition_id in edges_map[node_id]
):
):
return edges_map[node_id][condition_id]
return edges_map[node_id][condition_id]
else:
else:
print(
"Using stored condition evaluation result: No conditions met."
)
logger.debug("Using stored condition result: No conditions met.")
else:
else:
# Evaluate conditions
for condition in conditions:
for condition in conditions:
condition_id = condition.get("id")
condition_id = condition.get("id")

# Get latest event for evaluation, ignoring condition node informational events
# Get latest event for evaluation, ignoring condition node informational events
content = state.get("content", [])
content = state.get("content", [])

# Filter out events generated by condition nodes or informational messages
# Filter out events generated by condition nodes or that contain evaluation results
filtered_content = []
filtered_content = []
for event in content:
for event in content:
# Ignore events from condition nodes or that contain evaluation results
# Ignore events from condition nodes or that contain evaluation results
if not hasattr(event, "author") or not (
if not hasattr(event, "author") or not (
event.author.startswith("Condition")
or "Condition evaluated:" in str(event)
event.author.startswith("workflow-node:") and
"Condition evaluated:" in str(event)
):
):
filtered_content.append(event)
filtered_content.append(event)
@@ -687,9 +765,7 @@ class WorkflowAgent(BaseAgent):

if is_condition_met:
if is_condition_met:
any_condition_met = True
any_condition_met = True
print(
f"Condition {condition_id} met. Moving to the next node."
)
logger.debug(f"Condition {condition_id} met. Moving to next node.")

# Find the connection that uses this condition_id as a handle
# Find the connection that uses this condition_id as a handle
if (
if (
@@ -698,9 +774,7 @@ class WorkflowAgent(BaseAgent):
):
):
return edges_map[node_id][condition_id]
return edges_map[node_id][condition_id]
else:
else:
print(
f"Condition {condition_id} not met. Continuing evaluation or using default path."
)
logger.debug(f"Condition {condition_id} not met.")

# If no condition is met, use the bottom-handle if available
# If no condition is met, use the bottom-handle if available
if not any_condition_met:
if not any_condition_met:
@@ -708,14 +782,10 @@ class WorkflowAgent(BaseAgent):
node_id in edges_map
node_id in edges_map
and "bottom-handle" in edges_map[node_id]
and "bottom-handle" in edges_map[node_id]
):
):
print(
"No condition met. Using default path (bottom-handle)."
)
logger.debug("No condition met. Using default path (bottom-handle).")
return edges_map[node_id]["bottom-handle"]
return edges_map[node_id]["bottom-handle"]
else:
else:
print(
"No condition met and no default path. Closing the flow."
)
logger.debug("No condition met and no default path. Closing the flow.")
return END
return END

# For regular nodes, simply follow the first available connection
# For regular nodes, simply follow the first available connection
@@ -731,7 +801,7 @@ class WorkflowAgent(BaseAgent):
return edges_map[node_id][first_handle]
return edges_map[node_id][first_handle]

# If there is no output connection, close the flow
# If there is no output connection, close the flow
print(f"No output connection from node {node_id}. Closing the flow.")
logger.debug(f"No output connection from node {node_id}. Closing the flow.")
return END
return END

return router
return router
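The routers above appear to assume edges_map maps a source node id to its outgoing handles, each pointing at a target node id, with "bottom-handle" as the default path and condition ids acting as named handles. A sketch of that inferred shape (the concrete node ids are invented):

from typing import Dict, Optional

# Inferred structure: edges_map[source_node_id][handle_id] -> target_node_id
edges_map: Dict[str, Dict[str, str]] = {
    "condition-1": {
        "cond-abc": "agent-2",         # taken when condition cond-abc is met
        "bottom-handle": "message-1",  # default path when nothing matches
    },
    "agent-2": {"default": "delay-1"},
}

def route(node_id: str, condition_id: Optional[str]) -> str:
    handles = edges_map.get(node_id, {})
    if condition_id and condition_id in handles:
        return handles[condition_id]
    # "__end__" stands in for LangGraph's END sentinel here.
    return handles.get("bottom-handle", "__end__")

assert route("condition-1", "cond-abc") == "agent-2"
assert route("condition-1", None) == "message-1"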
@@ -745,6 +815,9 @@ class WorkflowAgent(BaseAgent):
# Extract nodes from the flow
# Extract nodes from the flow
nodes = flow_data.get("nodes", [])
nodes = flow_data.get("nodes", [])

if not nodes:
raise ValueError("Flow data must contain at least one node")

# Initialize StateGraph
# Initialize StateGraph
graph_builder = StateGraph(State)
graph_builder = StateGraph(State)
@@ -754,34 +827,60 @@ class WorkflowAgent(BaseAgent):
# Dictionary to store specific functions for each node
# Dictionary to store specific functions for each node
node_specific_functions = {}
node_specific_functions = {}

valid_node_types = set(node_functions.keys())

# Add nodes to the graph
# Add nodes to the graph
for node in nodes:
for node in nodes:
node_id = node.get("id")
node_id = node.get("id")
node_type = node.get("type")
node_type = node.get("type")
node_data = node.get("data", {})
node_data = node.get("data", {})

if node_type in node_functions:
# Create a specific function for this node
def create_node_function(node_type, node_id, node_data):
async def node_function(state):
# Consume the asynchronous generator and return the last result
result = None
if not node_id:
logger.warning(f"Skipping node without ID: {node}")
continue

if node_type not in valid_node_types:
logger.warning(f"Unknown node type '{node_type}' for node {node_id}. Skipping.")
continue

# Create a specific function for this node
def create_node_function(node_type, node_id, node_data):
async def node_function(state):
# Consume the asynchronous generator and return the last result
result = None
try:
async for item in node_functions[node_type](
async for item in node_functions[node_type](
state, node_id, node_data
state, node_id, node_data
):
):
result = item
result = item
return result
return result
except Exception as e:
logger.error(f"Error in node {node_id} ({node_type}): {str(e)}")
# Return error state
return {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text=f"Node error: {str(e)}")]),
)
],
"status": "node_error",
"node_outputs": state.get("node_outputs", {}),
"cycle_count": state.get("cycle_count", 0),
"conversation_history": state.get("conversation_history", []),
"session_id": state.get("session_id", ""),
}

return node_function
return node_function

# Add specific function to the dictionary
# Add specific function to the dictionary
node_specific_functions[node_id] = create_node_function(
node_specific_functions[node_id] = create_node_function(
node_type, node_id, node_data
node_type, node_id, node_data
)
)

# Add node to the graph
# Add node to the graph
print(f"Adding node {node_id} of type {node_type}")
logger.debug(f"Adding node {node_id} of type {node_type}")
graph_builder.add_node(node_id, node_specific_functions[node_id])
graph_builder.add_node(node_id, node_specific_functions[node_id])

# Create function to generate specific routers
# Create function to generate specific routers
create_router = self._create_flow_router(flow_data)
create_router = self._create_flow_router(flow_data)
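The create_node_function factory exists because closures in a loop capture variables, not values; without the factory every node_function would see the loop's last node_id. A minimal demonstration of the pitfall and the fix:

funcs_buggy, funcs_fixed = [], []

for node_id in ["a", "b", "c"]:
    # Buggy: the function closes over the loop variable itself.
    funcs_buggy.append(lambda: node_id)

    # Fixed: a factory (like create_node_function above) freezes the value.
    def make(frozen_id):
        return lambda: frozen_id

    funcs_fixed.append(make(node_id))

assert [f() for f in funcs_buggy] == ["c", "c", "c"]
assert [f() for f in funcs_fixed] == ["a", "b", "c"]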
@@ -808,8 +907,8 @@ class WorkflowAgent(BaseAgent):
node_router = create_router(node_id)
node_router = create_router(node_id)

# Add conditional connections
# Add conditional connections
print(f"Adding conditional connections for node {node_id}")
logger.debug(f"Adding conditional connections for node {node_id}")
print(f"Possible destinations: {edge_destinations}")
logger.debug(f"Possible destinations: {list(edge_destinations.keys())}")

graph_builder.add_conditional_edges(
graph_builder.add_conditional_edges(
node_id, node_router, edge_destinations
node_id, node_router, edge_destinations
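For context, add_conditional_edges wires a router function to a map of its possible return values. A minimal self-contained example of the same LangGraph pattern, assuming a recent langgraph release (the node and state here are toy stand-ins):

from typing import TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict):
    count: int

def bump(state: State) -> State:
    return {"count": state["count"] + 1}

def router(state: State) -> str:
    # Return a key of the path map below, like the node routers here.
    return "again" if state["count"] < 3 else "done"

builder = StateGraph(State)
builder.add_node("bump", bump)
builder.add_conditional_edges("bump", router, {"again": "bump", "done": END})
builder.set_entry_point("bump")
graph = builder.compile()

print(graph.invoke({"count": 0}))  # {'count': 3}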
@@ -825,35 +924,56 @@ class WorkflowAgent(BaseAgent):
# If there is no start-node, use the first node found
# If there is no start-node, use the first node found
if not entry_point and nodes:
if not entry_point and nodes:
entry_point = nodes[0].get("id")
entry_point = nodes[0].get("id")
logger.warning(f"No start-node found, using first node as entry point: {entry_point}")

# Define the entry point
# Define the entry point
if entry_point:
if entry_point:
print(f"Defining entry point: {entry_point}")
logger.info(f"Setting entry point: {entry_point}")
graph_builder.set_entry_point(entry_point)
graph_builder.set_entry_point(entry_point)
else:
raise ValueError("No valid entry point found for workflow")

# Compile the graph
# Compile the graph
return graph_builder.compile()
try:
compiled_graph = graph_builder.compile()
logger.info("Workflow graph compiled successfully")
return compiled_graph
except Exception as e:
logger.error(f"Error compiling workflow graph: {str(e)}")
raise ValueError(f"Error compiling workflow graph: {str(e)}")

async def _run_async_impl(
async def _run_async_impl(
self, ctx: InvocationContext
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
) -> AsyncGenerator[Event, None]:
"""Implementation of the workflow agent executing the defined workflow and returning results."""
"""Implementation of the workflow agent executing the defined workflow and returning results."""

if hasattr(self, 'model') and self.model:
logger.error(f"Workflow agent {self.name} should not have a model configured")
raise ValueError(f"Workflow agent {self.name} is an orchestrator and should not have a model. Models should be configured on sub-agents.")

try:
try:
logger.info(f"Starting workflow execution for agent: {self.name}")
logger.debug(f"Context session ID: {ctx.session.id if ctx.session else 'No session'}")

user_message = await self._extract_user_message(ctx)
user_message = await self._extract_user_message(ctx)
session_id = self._get_session_id(ctx)
session_id = self._get_session_id(ctx)

if not self.flow_json:
raise ValueError("Workflow agent has no flow_json configured")

graph = await self._create_graph(ctx, self.flow_json)
graph = await self._create_graph(ctx, self.flow_json)
initial_state = await self._prepare_initial_state(
initial_state = await self._prepare_initial_state(
ctx, user_message, session_id
ctx, user_message, session_id
)
)

print("\n🚀 Starting workflow execution:")
print(f"Initial content: {user_message[:100]}...")
logger.info(f"🚀 Starting workflow execution with initial message: {user_message[:100]}...")

# Iterate over the AsyncGenerator instead of using await
# Iterate over the AsyncGenerator instead of using await
async for event in self._execute_workflow(ctx, graph, initial_state):
async for event in self._execute_workflow(ctx, graph, initial_state):
yield event
yield event

except Exception as e:
except Exception as e:
logger.error(f"Error in workflow execution: {str(e)}", exc_info=True)
yield await self._handle_workflow_error(e)
yield await self._handle_workflow_error(e)

async def _extract_user_message(self, ctx: InvocationContext) -> str:
async def _extract_user_message(self, ctx: InvocationContext) -> str:
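The guard added above is the workflow-side half of the model validation this commit introduces: orchestrator agents must not carry a model, while LLM agents must carry a non-empty one, which is what previously leaked empty model strings into LiteLLM calls. A hedged sketch of the rule as a standalone check (the helper name is invented; the project enforces this inline):

from typing import Optional

def validate_agent_model(agent_type: str, model: Optional[str]) -> None:
    # Only LLM agents carry a model, and it must be a non-empty string.
    if agent_type == "llm":
        if not model or not model.strip():
            raise ValueError("LLM agents require a non-empty model")
    elif model:
        raise ValueError(
            f"Agent type '{agent_type}' is an orchestrator and must not set a model"
        )

validate_agent_model("llm", "gpt-4o-mini")
validate_agent_model("workflow", None)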
@@ -861,24 +981,36 @@ class WorkflowAgent(BaseAgent):
# Try to find message in session events
# Try to find message in session events
if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
for event in reversed(ctx.session.events):
for event in reversed(ctx.session.events):
if event.author == "user" and event.content and event.content.parts:
print("Message found in session events")
if (
hasattr(event, 'author') and
event.author == "user" and
hasattr(event, 'content') and
event.content and
hasattr(event.content, 'parts') and
event.content.parts
):
logger.debug("User message found in session events")
return event.content.parts[0].text
return event.content.parts[0].text

# Try to find message in session state
# Try to find message in session state
if ctx.session and ctx.session.state:
if ctx.session and hasattr(ctx.session, 'state') and ctx.session.state:
if "user_message" in ctx.session.state:
if "user_message" in ctx.session.state:
return ctx.session.state["user_message"]
return ctx.session.state["user_message"]
elif "message" in ctx.session.state:
elif "message" in ctx.session.state:
return ctx.session.state["message"]
return ctx.session.state["message"]

return ""
logger.warning("No user message found in context")
return "No user message provided"

def _get_session_id(self, ctx: InvocationContext) -> str:
def _get_session_id(self, ctx: InvocationContext) -> str:
"""Gets or generates a session ID."""
"""Gets or generates a session ID."""
if ctx.session and hasattr(ctx.session, "id"):
if ctx.session and hasattr(ctx.session, "id") and ctx.session.id:
return str(ctx.session.id)
return str(ctx.session.id)
return str(uuid.uuid4())
# Generate a new session ID
new_session_id = str(uuid.uuid4())
logger.debug(f"Generated new session ID: {new_session_id}")
return new_session_id

async def _prepare_initial_state(
async def _prepare_initial_state(
self, ctx: InvocationContext, user_message: str, session_id: str
self, ctx: InvocationContext, user_message: str, session_id: str
@@ -889,9 +1021,13 @@ class WorkflowAgent(BaseAgent):
content=Content(parts=[Part(text=user_message)]),
content=Content(parts=[Part(text=user_message)]),
)
)

conversation_history = ctx.session.events or [user_event]
conversation_history = []
if ctx.session and hasattr(ctx.session, 'events') and ctx.session.events:
conversation_history = ctx.session.events.copy()
else:
conversation_history = [user_event]

return State(
initial_state = State(
content=[user_event],
content=[user_event],
status="started",
status="started",
session_id=session_id,
session_id=session_id,
@@ -899,34 +1035,61 @@ class WorkflowAgent(BaseAgent):
node_outputs={},
node_outputs={},
conversation_history=conversation_history,
conversation_history=conversation_history,
)
)

logger.debug(f"Initial state prepared with {len(conversation_history)} history events")
return initial_state

async def _execute_workflow(
async def _execute_workflow(
self, ctx: InvocationContext, graph: StateGraph, initial_state: State
self, ctx: InvocationContext, graph: StateGraph, initial_state: State
) -> AsyncGenerator[Event, None]:
) -> AsyncGenerator[Event, None]:
"""Executes the workflow graph and yields events."""
"""Executes the workflow graph and yields events."""
sent_events = 0
sent_events = 0
total_iterations = 0
max_iterations = 100

async for state in graph.astream(initial_state, {"recursion_limit": 100}):
for node_state in state.values():
content = node_state.get("content", [])
for event in content[sent_events:]:
if event.author != "user":
yield event
sent_events = len(content)

# Execute sub-agents if any
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event

try:
async for state in graph.astream(initial_state, {"recursion_limit": max_iterations}):
total_iterations += 1

if total_iterations > max_iterations:
logger.warning(f"Maximum iterations ({max_iterations}) reached, stopping workflow")
break

for node_state in state.values():
content = node_state.get("content", [])

# Yield new events that haven't been sent yet
for event in content[sent_events:]:
if hasattr(event, 'author') and event.author != "user":
yield event

sent_events = len(content)

logger.info(f"Workflow completed after {total_iterations} iterations")

except Exception as e:
logger.error(f"Error during workflow execution: {str(e)}")
yield await self._handle_workflow_error(e)

if self.sub_agents:
logger.info(f"Executing {len(self.sub_agents)} sub-agents")
for sub_agent in self.sub_agents:
try:
async for event in sub_agent.run_async(ctx):
yield event
except Exception as e:
logger.error(f"Error executing sub-agent {sub_agent.name}: {str(e)}")
yield await self._handle_workflow_error(e)

async def _handle_workflow_error(self, error: Exception) -> Event:
async def _handle_workflow_error(self, error: Exception) -> Event:
"""Creates an error event for workflow execution errors."""
"""Creates an error event for workflow execution errors."""
error_msg = f"Error executing the workflow agent: {str(error)}"
error_msg = f"Error executing workflow agent '{self.name}': {str(error)}"
print(error_msg)
logger.error(error_msg)

return Event(
return Event(
author=f"workflow-error:{self.name}",
author=f"workflow-error:{self.name}",
content=Content(
content=Content(
role="agent",
role="agent",
parts=[Part(text=error_msg)],
parts=[Part(text=error_msg)],
),
),
)
)
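For reference, _execute_workflow consumes graph.astream(), which yields one {node_name: node_update} mapping per step; the sent_events counter ensures already-emitted events are not re-yielded. A compact, self-contained sketch of that consumption pattern, assuming a recent langgraph release (the toy graph and event strings are invented):

import asyncio
from typing import TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict):
    content: list

def add_event(state: State) -> State:
    return {"content": state["content"] + [f"event-{len(state['content'])}"]}

builder = StateGraph(State)
builder.add_node("node", add_event)
builder.add_conditional_edges(
    "node", lambda s: "again" if len(s["content"]) < 3 else "done",
    {"again": "node", "done": END},
)
builder.set_entry_point("node")
graph = builder.compile()

async def consume() -> None:
    sent = 0
    # astream yields {node_name: state_update} per step, so we re-slice
    # content and only forward entries not seen before.
    async for step in graph.astream({"content": []}, {"recursion_limit": 100}):
        for node_state in step.values():
            content = node_state.get("content", [])
            for event in content[sent:]:
                print("yield", event)
            sent = len(content)

asyncio.run(consume())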