fix: comprehensive websocket and LLM model validation

- Fix original websocket format error
- Add robust model validation for all agent types
- Prevent empty model strings in LiteLLM calls
- Update Pydantic V2 compatibility (dict() -> model_dump())
- Improve error handling in workflow agents
- Add comprehensive logging and validation
This commit is contained in:
Anderson Lemes 2025-06-04 22:26:16 -03:00
parent 473cf63252
commit d4618fa345
5 changed files with 830 additions and 319 deletions

View File

@ -158,7 +158,7 @@ async def get_agent_messages(
processed_events = [] processed_events = []
for event in events: for event in events:
event_dict = event.dict() event_dict = event.model_dump()
def process_dict(d): def process_dict(d):
if isinstance(d, dict): if isinstance(d, dict):

View File

@ -27,7 +27,7 @@
""" """
from pydantic import BaseModel, Field, validator, UUID4, ConfigDict from pydantic import BaseModel, Field, field_validator, model_validator, UUID4, ConfigDict
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
@ -40,7 +40,8 @@ class ClientBase(BaseModel):
name: str name: str
email: Optional[str] = None email: Optional[str] = None
@validator("email") @field_validator("email")
@classmethod
def validate_email(cls, v): def validate_email(cls, v):
if v is None: if v is None:
return v return v
@ -58,8 +59,7 @@ class Client(ClientBase):
id: UUID id: UUID
created_at: datetime created_at: datetime
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class ApiKeyBase(BaseModel): class ApiKeyBase(BaseModel):
@ -101,7 +101,7 @@ class AgentBase(BaseModel):
description="Agent type (llm, sequential, parallel, loop, a2a, workflow, task)", description="Agent type (llm, sequential, parallel, loop, a2a, workflow, task)",
) )
model: Optional[str] = Field( model: Optional[str] = Field(
None, description="Agent model (required only for llm type)" None, description="LLM model identifier (required for LLM agents only)"
) )
api_key_id: Optional[UUID4] = Field( api_key_id: Optional[UUID4] = Field(
None, description="Reference to a stored API Key ID" None, description="Reference to a stored API Key ID"
@ -115,8 +115,13 @@ class AgentBase(BaseModel):
) )
config: Any = Field(None, description="Agent configuration based on type") config: Any = Field(None, description="Agent configuration based on type")
@validator("name") @field_validator("name")
def validate_name(cls, v, values): @classmethod
def validate_name(cls, v, info):
# Get values from validation context
values = info.data if hasattr(info, 'data') else {}
# A2A agents can have optional names
if values.get("type") == "a2a": if values.get("type") == "a2a":
return v return v
@ -127,107 +132,246 @@ class AgentBase(BaseModel):
raise ValueError("Agent name cannot contain spaces or special characters") raise ValueError("Agent name cannot contain spaces or special characters")
return v return v
@validator("type") @field_validator("type")
@classmethod
def validate_type(cls, v): def validate_type(cls, v):
if v not in [ valid_types = [
"llm", "llm",
"sequential", "sequential",
"parallel", "parallel",
"loop", "loop",
"a2a", "a2a",
"workflow", "workflow",
"task", "task"
]: ]
if v not in valid_types:
raise ValueError( raise ValueError(
"Invalid agent type. Must be: llm, sequential, parallel, loop, a2a, workflow or task" f"Invalid agent type '{v}'. Must be one of: {', '.join(valid_types)}"
) )
return v return v
@validator("agent_card_url") @field_validator("agent_card_url")
def validate_agent_card_url(cls, v, values): @classmethod
if "type" in values and values["type"] == "a2a": def validate_agent_card_url(cls, v, info):
values = info.data if hasattr(info, 'data') else {}
if values.get("type") == "a2a":
if not v: if not v:
raise ValueError("agent_card_url is required for a2a type agents") raise ValueError("agent_card_url is required for a2a type agents")
if not v.endswith("/.well-known/agent.json"): if not v.endswith("/.well-known/agent.json"):
raise ValueError("agent_card_url must end with /.well-known/agent.json") raise ValueError("agent_card_url must end with /.well-known/agent.json")
return v return v
@validator("model") @field_validator("model")
def validate_model(cls, v, values): @classmethod
if "type" in values and values["type"] == "llm" and not v: def validate_model(cls, v, info):
raise ValueError("Model is required for llm type agents") values = info.data if hasattr(info, 'data') else {}
agent_type = values.get("type")
if agent_type == "llm":
# Para agentes LLM, o modelo é obrigatório e não pode ser vazio
if not v or (isinstance(v, str) and v.strip() == ""):
raise ValueError(
"LLM agents require a valid model configuration. "
"Please specify a model identifier (e.g., 'gpt-4', 'claude-3-sonnet', 'gemini-pro')"
)
# Verificar se o modelo tem um formato válido
if isinstance(v, str) and len(v.strip()) < 3:
raise ValueError("Model identifier must be at least 3 characters long")
elif agent_type in ["workflow", "task", "sequential", "parallel", "loop"]:
# Para estes tipos, não devem ter modelo
if v and (isinstance(v, str) and v.strip()):
# Avisar mas permitir (será removido durante a criação)
import logging
logger = logging.getLogger(__name__)
logger.warning(f"{agent_type} agents don't need model configuration. Model will be ignored.")
return v return v
@validator("api_key_id") @field_validator("api_key_id")
def validate_api_key_id(cls, v, values): @classmethod
def validate_api_key_id(cls, v, info):
values = info.data if hasattr(info, 'data') else {}
agent_type = values.get("type")
# API key é obrigatório para agentes LLM (a menos que esteja na config)
if agent_type == "llm" and not v:
# Verificar se tem API key na config
config = values.get("config", {})
if not config or not config.get("api_key"):
# Não falhar aqui, deixar a validação para o momento da criação
pass
return v return v
@validator("config") @field_validator("config")
def validate_config(cls, v, values): @classmethod
if "type" in values and values["type"] == "a2a": def validate_config(cls, v, info):
values = info.data if hasattr(info, 'data') else {}
agent_type = values.get("type")
if not agent_type:
return v
# A2A agents têm config opcional
if agent_type == "a2a":
return v or {} return v or {}
if "type" not in values: # Workflow agents têm config específico para workflow
if agent_type == "workflow":
if v and isinstance(v, dict):
if not v.get("workflow"):
raise ValueError("Workflow agents must have 'workflow' configuration")
return v return v
# For workflow agents, we do not perform any validation # Config é obrigatório para outros tipos (exceto a2a)
if "type" in values and values["type"] == "workflow": if not v and agent_type not in ["a2a"]:
return v
if not v and values.get("type") != "a2a":
raise ValueError( raise ValueError(
f"Configuration is required for {values.get('type')} agent type" f"Configuration is required for {agent_type} agent type"
) )
if values["type"] == "llm": # Validação específica por tipo
if isinstance(v, dict): if agent_type == "llm":
try: return cls._validate_llm_config(v)
# Convert the dictionary to LLMConfig elif agent_type in ["sequential", "parallel", "loop"]:
v = LLMConfig(**v) return cls._validate_composite_config(v, agent_type)
except Exception as e: elif agent_type == "task":
raise ValueError(f"Invalid LLM configuration for agent: {str(e)}") return cls._validate_task_config(v)
elif not isinstance(v, LLMConfig):
raise ValueError("Invalid LLM configuration for agent")
elif values["type"] in ["sequential", "parallel", "loop"]:
if not isinstance(v, dict):
raise ValueError(f'Invalid configuration for agent {values["type"]}')
if "sub_agents" not in v:
raise ValueError(f'Agent {values["type"]} must have sub_agents')
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
if not v["sub_agents"]:
raise ValueError(
f'Agent {values["type"]} must have at least one sub-agent'
)
elif values["type"] == "task":
if not isinstance(v, dict):
raise ValueError(f'Invalid configuration for agent {values["type"]}')
if "tasks" not in v:
raise ValueError(f'Agent {values["type"]} must have tasks')
if not isinstance(v["tasks"], list):
raise ValueError("tasks must be a list")
if not v["tasks"]:
raise ValueError(f'Agent {values["type"]} must have at least one task')
for task in v["tasks"]:
if not isinstance(task, dict):
raise ValueError("Each task must be a dictionary")
required_fields = ["agent_id", "description", "expected_output"]
for field in required_fields:
if field not in task:
raise ValueError(f"Task missing required field: {field}")
if "sub_agents" in v and v["sub_agents"] is not None:
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
return v
return v return v
@classmethod
def _validate_llm_config(cls, v):
"""Valida configuração para agentes LLM"""
if isinstance(v, dict):
try:
# Convert the dictionary to LLMConfig
v = LLMConfig(**v)
except Exception as e:
raise ValueError(f"Invalid LLM configuration: {str(e)}")
elif not isinstance(v, LLMConfig):
raise ValueError("Invalid LLM configuration format")
return v
@classmethod
def _validate_composite_config(cls, v, agent_type):
"""Valida configuração para agentes compostos (sequential, parallel, loop)"""
if not isinstance(v, dict):
raise ValueError(f'Configuration for {agent_type} agent must be a dictionary')
if "sub_agents" not in v:
raise ValueError(f'{agent_type} agents must have sub_agents configuration')
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
if not v["sub_agents"]:
raise ValueError(
f'{agent_type} agents must have at least one sub-agent'
)
# Validação específica para LoopAgent
if agent_type == "loop":
max_iterations = v.get("max_iterations", 5)
if not isinstance(max_iterations, int) or max_iterations <= 0:
raise ValueError("max_iterations must be a positive integer")
return v
@classmethod
def _validate_task_config(cls, v):
"""Valida configuração para agentes de task"""
if not isinstance(v, dict):
raise ValueError('Configuration for task agent must be a dictionary')
if "tasks" not in v:
raise ValueError('Task agents must have tasks configuration')
if not isinstance(v["tasks"], list):
raise ValueError("tasks must be a list")
if not v["tasks"]:
raise ValueError('Task agents must have at least one task')
# Validar cada task individualmente
for i, task in enumerate(v["tasks"]):
if not isinstance(task, dict):
raise ValueError(f"Task {i+1} must be a dictionary")
required_fields = ["agent_id", "description", "expected_output"]
for field in required_fields:
if field not in task:
raise ValueError(f"Task {i+1} missing required field: {field}")
# Verificar se os campos não estão vazios
if not task[field] or (isinstance(task[field], str) and not task[field].strip()):
raise ValueError(f"Task {i+1} field '{field}' cannot be empty")
# Validar sub_agents se presente
if "sub_agents" in v and v["sub_agents"] is not None:
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
return v
@model_validator(mode='after')
def validate_agent_consistency(self):
"""Validação cruzada entre campos do agente"""
# Verificar consistência entre tipo e configurações
if self.type == "llm":
# LLM agents devem ter modelo
if not self.model or (isinstance(self.model, str) and self.model.strip() == ""):
raise ValueError("LLM agents must have a valid model")
# LLM agents devem ter API key (na config ou api_key_id)
has_api_key = bool(self.api_key_id)
if not has_api_key and self.config:
config_dict = self.config if isinstance(self.config, dict) else self.config.__dict__
has_api_key = bool(config_dict.get("api_key"))
if not has_api_key:
raise ValueError("LLM agents must have an API key configured")
elif self.type in ["workflow", "task", "sequential", "parallel", "loop"]:
# Orchestrator agents não devem ter modelo
if self.model and isinstance(self.model, str) and self.model.strip():
import logging
logger = logging.getLogger(__name__)
logger.warning(f"{self.type} agents don't need model configuration. Clearing model.")
self.model = None
elif self.type == "a2a":
# A2A agents devem ter agent_card_url
if not self.agent_card_url:
raise ValueError("A2A agents must have agent_card_url")
return self
class AgentCreate(AgentBase): class AgentCreate(AgentBase):
client_id: UUID client_id: UUID
@model_validator(mode='after')
def validate_creation_requirements(self):
"""Validações específicas para criação de agentes"""
# Chamar validação da classe pai
super().validate_agent_consistency()
# Validações específicas para criação
if self.type == "llm":
# Para criação, ser mais rigoroso com modelo
if not self.model or len(self.model.strip()) < 3:
raise ValueError(
"LLM agents require a valid model identifier (minimum 3 characters). "
"Examples: 'gpt-4', 'claude-3-sonnet', 'gemini-pro'"
)
return self
class Agent(AgentBase): class Agent(AgentBase):
id: UUID id: UUID
@ -237,17 +381,17 @@ class Agent(AgentBase):
agent_card_url: Optional[str] = None agent_card_url: Optional[str] = None
folder_id: Optional[UUID4] = None folder_id: Optional[UUID4] = None
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
@validator("agent_card_url", pre=True) @field_validator("agent_card_url", mode='before')
def set_agent_card_url(cls, v, values): @classmethod
def set_agent_card_url(cls, v, info):
if v: if v:
return v return v
values = info.data if hasattr(info, 'data') else {}
if "id" in values: if "id" in values:
from os import getenv from os import getenv
return f"{getenv('API_URL', '')}/api/v1/a2a/{values['id']}/.well-known/agent.json" return f"{getenv('API_URL', '')}/api/v1/a2a/{values['id']}/.well-known/agent.json"
return v return v
@ -262,6 +406,7 @@ class ToolConfig(BaseModel):
inputModes: List[str] = Field(default_factory=list) inputModes: List[str] = Field(default_factory=list)
outputModes: List[str] = Field(default_factory=list) outputModes: List[str] = Field(default_factory=list)
# Last edited by Arley Peter on 2025-05-17 # Last edited by Arley Peter on 2025-05-17
class MCPServerBase(BaseModel): class MCPServerBase(BaseModel):
name: str name: str
@ -272,6 +417,29 @@ class MCPServerBase(BaseModel):
tools: Optional[List[ToolConfig]] = Field(default_factory=list) tools: Optional[List[ToolConfig]] = Field(default_factory=list)
type: str = Field(default="official") type: str = Field(default="official")
@field_validator("name")
@classmethod
def validate_name(cls, v):
if not v or not v.strip():
raise ValueError("MCP Server name cannot be empty")
return v.strip()
@field_validator("config_type")
@classmethod
def validate_config_type(cls, v):
valid_types = ["studio", "custom"]
if v not in valid_types:
raise ValueError(f"config_type must be one of: {valid_types}")
return v
@field_validator("type")
@classmethod
def validate_type(cls, v):
valid_types = ["official", "custom"]
if v not in valid_types:
raise ValueError(f"type must be one of: {valid_types}")
return v
class MCPServerCreate(MCPServerBase): class MCPServerCreate(MCPServerBase):
pass pass
@ -282,8 +450,7 @@ class MCPServer(MCPServerBase):
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class ToolBase(BaseModel): class ToolBase(BaseModel):
@ -292,6 +459,13 @@ class ToolBase(BaseModel):
config_json: Dict[str, Any] = Field(default_factory=dict) config_json: Dict[str, Any] = Field(default_factory=dict)
environments: Dict[str, Any] = Field(default_factory=dict) environments: Dict[str, Any] = Field(default_factory=dict)
@field_validator("name")
@classmethod
def validate_name(cls, v):
if not v or not v.strip():
raise ValueError("Tool name cannot be empty")
return v.strip()
class ToolCreate(ToolBase): class ToolCreate(ToolBase):
pass pass
@ -302,14 +476,20 @@ class Tool(ToolBase):
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class AgentFolderBase(BaseModel): class AgentFolderBase(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
@field_validator("name")
@classmethod
def validate_name(cls, v):
if not v or not v.strip():
raise ValueError("Folder name cannot be empty")
return v.strip()
class AgentFolderCreate(AgentFolderBase): class AgentFolderCreate(AgentFolderBase):
client_id: UUID4 client_id: UUID4
@ -326,3 +506,86 @@ class AgentFolder(AgentFolderBase):
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class AgentTypeInfo(BaseModel):
"""Informações sobre tipos de agente válidos"""
type: str
requires_model: bool
requires_config: bool
description: str
@classmethod
def get_valid_types(cls) -> Dict[str, 'AgentTypeInfo']:
"""Retorna informações sobre todos os tipos válidos de agente"""
return {
"llm": cls(
type="llm",
requires_model=True,
requires_config=True,
description="Large Language Model agent - requires model and API key"
),
"workflow": cls(
type="workflow",
requires_model=False,
requires_config=True,
description="Workflow orchestrator agent - uses LangGraph for complex flows"
),
"task": cls(
type="task",
requires_model=False,
requires_config=True,
description="Task management agent - coordinates multiple tasks"
),
"sequential": cls(
type="sequential",
requires_model=False,
requires_config=True,
description="Sequential execution agent - runs sub-agents in order"
),
"parallel": cls(
type="parallel",
requires_model=False,
requires_config=True,
description="Parallel execution agent - runs sub-agents concurrently"
),
"loop": cls(
type="loop",
requires_model=False,
requires_config=True,
description="Loop execution agent - repeats sub-agents with conditions"
),
"a2a": cls(
type="a2a",
requires_model=False,
requires_config=False,
description="Agent-to-Agent communication - external agent integration"
)
}
class ModelValidationResult(BaseModel):
"""Resultado da validação de modelo"""
is_valid: bool
error_message: Optional[str] = None
warnings: List[str] = Field(default_factory=list)
@classmethod
def success(cls, warnings: List[str] = None) -> 'ModelValidationResult':
return cls(is_valid=True, warnings=warnings or [])
@classmethod
def failure(cls, error_message: str) -> 'ModelValidationResult':
return cls(is_valid=False, error_message=error_message)
class AgentValidationSummary(BaseModel):
"""Resumo de validação de agente"""
agent_id: Optional[UUID] = None
agent_name: Optional[str] = None
agent_type: str
is_valid: bool
model_validation: ModelValidationResult
config_validation: ModelValidationResult
general_errors: List[str] = Field(default_factory=list)
warnings: List[str] = Field(default_factory=list)

View File

@ -67,15 +67,39 @@ class AgentBuilder:
if agent_tools_ids and isinstance(agent_tools_ids, list): if agent_tools_ids and isinstance(agent_tools_ids, list):
for agent_tool_id in agent_tools_ids: for agent_tool_id in agent_tools_ids:
sub_agent = get_agent(self.db, agent_tool_id) sub_agent = get_agent(self.db, agent_tool_id)
llm_agent, _ = await self.build_llm_agent(sub_agent) if sub_agent:
if llm_agent: # Verificar se o sub_agent é do tipo LLM antes de criar LlmAgent
agent_tools.append(AgentTool(agent=llm_agent)) if sub_agent.type == "llm":
llm_agent, _ = await self.build_llm_agent(sub_agent)
if llm_agent:
agent_tools.append(AgentTool(agent=llm_agent))
else:
logger.warning(f"Agent tool {agent_tool_id} is not of type 'llm', skipping")
else:
logger.warning(f"Agent tool {agent_tool_id} not found")
return agent_tools return agent_tools
def _validate_llm_agent_model(self, agent) -> None:
"""Validate that LLM agent has a proper model configuration."""
if not hasattr(agent, 'model') or not agent.model:
logger.error(f"LLM agent {agent.name} does not have a model configured")
raise ValueError(f"LLM agent {agent.name} requires a model configuration")
if isinstance(agent.model, str) and agent.model.strip() == "":
logger.error(f"LLM agent {agent.name} has an empty model string")
raise ValueError(f"LLM agent {agent.name} has an empty model configuration")
logger.info(f"Model validation passed for agent {agent.name}: {agent.model}")
async def _create_llm_agent( async def _create_llm_agent(
self, agent, enabled_tools: List[str] = [] self, agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]: ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Create an LLM agent from the agent data.""" """Create an LLM agent from the agent data."""
self._validate_llm_agent_model(agent)
logger.info(f"Creating LLM agent: {agent.name} with model: {agent.model}")
# Get custom tools from the configuration # Get custom tools from the configuration
custom_tools = [] custom_tools = []
custom_tools = self.custom_tool_builder.build_tools(agent.config) custom_tools = self.custom_tool_builder.build_tools(agent.config)
@ -110,7 +134,7 @@ class AgentBuilder:
current_day_of_week=current_day_of_week, current_day_of_week=current_day_of_week,
current_date_iso=current_date_iso, current_date_iso=current_date_iso,
current_time=current_time, current_time=current_time,
) ) if agent.instruction else ""
# add role on beginning of the prompt # add role on beginning of the prompt
if agent.role: if agent.role:
@ -170,21 +194,27 @@ class AgentBuilder:
f"Agent {agent.name} does not have a configured API key" f"Agent {agent.name} does not have a configured API key"
) )
return ( if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
LlmAgent( raise ValueError(f"Cannot create LiteLlm with empty model for agent {agent.name}")
try:
llm_agent = LlmAgent(
name=agent.name, name=agent.name,
model=LiteLlm(model=agent.model, api_key=api_key), model=LiteLlm(model=agent.model, api_key=api_key),
instruction=formatted_prompt, instruction=formatted_prompt,
description=agent.description, description=agent.description,
tools=all_tools, tools=all_tools,
), )
mcp_exit_stack, logger.info(f"LLM agent created successfully: {agent.name}")
) return llm_agent, mcp_exit_stack
except Exception as e:
logger.error(f"Error creating LLM agent {agent.name}: {str(e)}")
raise ValueError(f"Error creating LLM agent {agent.name}: {str(e)}")
async def _get_sub_agents( async def _get_sub_agents(
self, sub_agent_ids: List[str] self, sub_agent_ids: List[str]
) -> List[Tuple[LlmAgent, Optional[AsyncExitStack]]]: ) -> List[Tuple[BaseAgent, Optional[AsyncExitStack]]]:
"""Get and create LLM sub-agents.""" """Get and create sub-agents with proper type validation."""
sub_agents = [] sub_agents = []
for sub_agent_id in sub_agent_ids: for sub_agent_id in sub_agent_ids:
sub_agent_id_str = str(sub_agent_id) sub_agent_id_str = str(sub_agent_id)
@ -197,39 +227,50 @@ class AgentBuilder:
logger.info(f"Sub-agent found: {agent.name} (type: {agent.type})") logger.info(f"Sub-agent found: {agent.name} (type: {agent.type})")
if agent.type == "llm": try:
sub_agent, exit_stack = await self._create_llm_agent(agent) if agent.type == "llm":
elif agent.type == "a2a": # Verificar se tem modelo antes de criar
sub_agent, exit_stack = await self.build_a2a_agent(agent) if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
elif agent.type == "workflow": logger.error(f"LLM sub-agent {agent.name} does not have a model configured")
sub_agent, exit_stack = await self.build_workflow_agent(agent) raise ValueError(f"LLM sub-agent {agent.name} requires a model configuration")
elif agent.type == "task": sub_agent, exit_stack = await self._create_llm_agent(agent)
sub_agent, exit_stack = await self.build_task_agent(agent) elif agent.type == "a2a":
elif agent.type == "sequential": sub_agent, exit_stack = await self.build_a2a_agent(agent)
sub_agent, exit_stack = await self.build_composite_agent(agent) elif agent.type == "workflow":
elif agent.type == "parallel": # Workflow agents não precisam de modelo
sub_agent, exit_stack = await self.build_composite_agent(agent) sub_agent, exit_stack = await self.build_workflow_agent(agent)
elif agent.type == "loop": elif agent.type == "task":
sub_agent, exit_stack = await self.build_composite_agent(agent) sub_agent, exit_stack = await self.build_task_agent(agent)
else: elif agent.type == "sequential":
raise ValueError(f"Invalid agent type: {agent.type}") sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "parallel":
sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "loop":
sub_agent, exit_stack = await self.build_composite_agent(agent)
else:
raise ValueError(f"Invalid agent type: {agent.type}")
sub_agents.append((sub_agent, exit_stack)) sub_agents.append((sub_agent, exit_stack))
logger.info(f"Sub-agent added: {agent.name}") logger.info(f"Sub-agent added: {agent.name}")
except Exception as e:
logger.error(f"Error creating sub-agent {agent.name}: {str(e)}")
raise ValueError(f"Error creating sub-agent {agent.name}: {str(e)}")
logger.info(f"Sub-agents created: {len(sub_agents)}") logger.info(f"Sub-agents created: {len(sub_agents)}")
logger.info(f"Sub-agents: {str(sub_agents)}")
return sub_agents return sub_agents
async def build_llm_agent( async def build_llm_agent(
self, root_agent, enabled_tools: List[str] = [] self, root_agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]: ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Build an LLM agent with its sub-agents.""" """Build an LLM agent with its sub-agents."""
logger.info("Creating LLM agent") logger.info(f"Creating LLM agent: {root_agent.name}")
if root_agent.type != "llm":
raise ValueError(f"Expected LLM agent, got {root_agent.type}")
sub_agents = [] sub_agents = []
if root_agent.config.get("sub_agents"): if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents( sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents") root_agent.config.get("sub_agents")
) )
@ -241,20 +282,21 @@ class AgentBuilder:
if sub_agents: if sub_agents:
root_llm_agent.sub_agents = sub_agents root_llm_agent.sub_agents = sub_agents
logger.info(f"LLM agent built successfully: {root_agent.name}")
return root_llm_agent, exit_stack return root_llm_agent, exit_stack
async def build_a2a_agent( async def build_a2a_agent(
self, root_agent self, root_agent
) -> Tuple[BaseAgent, Optional[AsyncExitStack]]: ) -> Tuple[A2ACustomAgent, Optional[AsyncExitStack]]:
"""Build an A2A agent with its sub-agents.""" """Build an A2A agent with its sub-agents."""
logger.info(f"Creating A2A agent from {root_agent.agent_card_url}") logger.info(f"Creating A2A agent from {root_agent.name}")
if not root_agent.agent_card_url: if not root_agent.agent_card_url:
raise ValueError("agent_card_url is required for a2a agents") raise ValueError("agent_card_url is required for a2a agents")
try: try:
sub_agents = [] sub_agents = []
if root_agent.config.get("sub_agents"): if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents( sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents") root_agent.config.get("sub_agents")
) )
@ -288,6 +330,9 @@ class AgentBuilder:
"""Build a workflow agent with its sub-agents.""" """Build a workflow agent with its sub-agents."""
logger.info(f"Creating Workflow agent from {root_agent.name}") logger.info(f"Creating Workflow agent from {root_agent.name}")
if root_agent.type != "workflow":
raise ValueError(f"Expected workflow agent, got {root_agent.type}")
agent_config = root_agent.config or {} agent_config = root_agent.config or {}
if not agent_config.get("workflow"): if not agent_config.get("workflow"):
@ -295,7 +340,7 @@ class AgentBuilder:
try: try:
sub_agents = [] sub_agents = []
if root_agent.config.get("sub_agents"): if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents( sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents") root_agent.config.get("sub_agents")
) )
@ -304,15 +349,20 @@ class AgentBuilder:
config = root_agent.config or {} config = root_agent.config or {}
timeout = config.get("timeout", 300) timeout = config.get("timeout", 300)
workflow_agent = WorkflowAgent( kwargs = {
name=root_agent.name, "name": root_agent.name,
flow_json=agent_config.get("workflow"), "flow_json": agent_config.get("workflow"),
timeout=timeout, "timeout": timeout,
description=root_agent.description "description": root_agent.description or f"Workflow Agent for {root_agent.name}",
or f"Workflow Agent for {root_agent.name}", "sub_agents": sub_agents,
sub_agents=sub_agents, "db": self.db,
db=self.db, }
)
# Se o root_agent tiver modelo, não passá-lo para o WorkflowAgent
if hasattr(root_agent, 'model') and root_agent.model:
logger.warning(f"Workflow agent {root_agent.name} has model '{root_agent.model}' configured, but workflow agents should not have models. Ignoring model.")
workflow_agent = WorkflowAgent(**kwargs)
logger.info(f"Workflow agent created successfully: {root_agent.name}") logger.info(f"Workflow agent created successfully: {root_agent.name}")
@ -328,6 +378,9 @@ class AgentBuilder:
"""Build a task agent with its sub-agents.""" """Build a task agent with its sub-agents."""
logger.info(f"Creating Task agent: {root_agent.name}") logger.info(f"Creating Task agent: {root_agent.name}")
if root_agent.type != "task":
raise ValueError(f"Expected task agent, got {root_agent.type}")
agent_config = root_agent.config or {} agent_config = root_agent.config or {}
if not agent_config.get("tasks"): if not agent_config.get("tasks"):
@ -336,7 +389,7 @@ class AgentBuilder:
try: try:
# Get sub-agents if there are any # Get sub-agents if there are any
sub_agents = [] sub_agents = []
if root_agent.config.get("sub_agents"): if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents( sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents") root_agent.config.get("sub_agents")
) )
@ -380,7 +433,11 @@ class AgentBuilder:
f"Processing sub-agents for agent {root_agent.type} (ID: {root_agent.id}, Name: {root_agent.name})" f"Processing sub-agents for agent {root_agent.type} (ID: {root_agent.id}, Name: {root_agent.name})"
) )
if not root_agent.config.get("sub_agents"): valid_composite_types = ["sequential", "parallel", "loop"]
if root_agent.type not in valid_composite_types:
raise ValueError(f"Expected composite agent type ({valid_composite_types}), got {root_agent.type}")
if not root_agent.config or not root_agent.config.get("sub_agents"):
logger.error( logger.error(
f"Sub_agents configuration not found or empty for agent {root_agent.name}" f"Sub_agents configuration not found or empty for agent {root_agent.name}"
) )
@ -401,39 +458,51 @@ class AgentBuilder:
sub_agents = [agent for agent, _ in sub_agents_with_stacks] sub_agents = [agent for agent, _ in sub_agents_with_stacks]
logger.info(f"Extracted sub-agents: {[agent.name for agent in sub_agents]}") logger.info(f"Extracted sub-agents: {[agent.name for agent in sub_agents]}")
if root_agent.type == "sequential": if not sub_agents:
logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents") raise ValueError(f"No valid sub-agents found for {root_agent.type} agent {root_agent.name}")
return (
SequentialAgent( try:
name=root_agent.name, if root_agent.type == "sequential":
sub_agents=sub_agents, logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents")
description=root_agent.config.get("description", ""), return (
), SequentialAgent(
None, name=root_agent.name,
) sub_agents=sub_agents,
elif root_agent.type == "parallel": description=root_agent.description or root_agent.config.get("description", ""),
logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents") ),
return ( None,
ParallelAgent( )
name=root_agent.name, elif root_agent.type == "parallel":
sub_agents=sub_agents, logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents")
description=root_agent.config.get("description", ""), return (
), ParallelAgent(
None, name=root_agent.name,
) sub_agents=sub_agents,
elif root_agent.type == "loop": description=root_agent.description or root_agent.config.get("description", ""),
logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents") ),
return ( None,
LoopAgent( )
name=root_agent.name, elif root_agent.type == "loop":
sub_agents=sub_agents, logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents")
description=root_agent.config.get("description", ""), max_iterations = root_agent.config.get("max_iterations", 5)
max_iterations=root_agent.config.get("max_iterations", 5), if max_iterations <= 0:
), logger.warning(f"Invalid max_iterations ({max_iterations}) for LoopAgent, using default 5")
None, max_iterations = 5
) return (
else: LoopAgent(
raise ValueError(f"Invalid agent type: {root_agent.type}") name=root_agent.name,
sub_agents=sub_agents,
description=root_agent.description or root_agent.config.get("description", ""),
max_iterations=max_iterations,
),
None,
)
else:
raise ValueError(f"Invalid composite agent type: {root_agent.type}")
except Exception as e:
logger.error(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")
raise ValueError(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")
async def build_agent(self, root_agent, enabled_tools: List[str] = []) -> Tuple[ async def build_agent(self, root_agent, enabled_tools: List[str] = []) -> Tuple[
LlmAgent LlmAgent
@ -446,13 +515,29 @@ class AgentBuilder:
Optional[AsyncExitStack], Optional[AsyncExitStack],
]: ]:
"""Build the appropriate agent based on the type of the root agent.""" """Build the appropriate agent based on the type of the root agent."""
if root_agent.type == "llm":
return await self.build_llm_agent(root_agent, enabled_tools) if not root_agent:
elif root_agent.type == "a2a": raise ValueError("root_agent cannot be None")
return await self.build_a2a_agent(root_agent)
elif root_agent.type == "workflow": if not hasattr(root_agent, 'type') or not root_agent.type:
return await self.build_workflow_agent(root_agent) raise ValueError("root_agent must have a valid type")
elif root_agent.type == "task":
return await self.build_task_agent(root_agent) logger.info(f"Building agent: {root_agent.name} (type: {root_agent.type})")
else:
return await self.build_composite_agent(root_agent) try:
if root_agent.type == "llm":
return await self.build_llm_agent(root_agent, enabled_tools)
elif root_agent.type == "a2a":
return await self.build_a2a_agent(root_agent)
elif root_agent.type == "workflow":
return await self.build_workflow_agent(root_agent)
elif root_agent.type == "task":
return await self.build_task_agent(root_agent)
elif root_agent.type in ["sequential", "parallel", "loop"]:
return await self.build_composite_agent(root_agent)
else:
raise ValueError(f"Unknown agent type: {root_agent.type}")
except Exception as e:
logger.error(f"Error building agent {root_agent.name}: {str(e)}")
raise

View File

@ -458,7 +458,7 @@ async def run_agent_stream(
async for event in events_async: async for event in events_async:
try: try:
event_dict = event.dict() event_dict = event.model_dump()
event_dict = convert_sets(event_dict) event_dict = convert_sets(event_dict)
if "content" in event_dict and event_dict["content"]: if "content" in event_dict and event_dict["content"]:

View File

@ -30,6 +30,7 @@
""" """
from datetime import datetime from datetime import datetime
from google.adk.agents import BaseAgent from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.invocation_context import InvocationContext
@ -40,11 +41,14 @@ from typing import AsyncGenerator, Dict, Any, List, TypedDict
import uuid import uuid
from src.services.agent_service import get_agent from src.services.agent_service import get_agent
from src.utils.logger import setup_logger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from langgraph.graph import StateGraph, END from langgraph.graph import StateGraph, END
logger = setup_logger(__name__)
class State(TypedDict): class State(TypedDict):
content: List[Event] content: List[Event]
@ -63,6 +67,9 @@ class WorkflowAgent(BaseAgent):
This agent allows defining and executing complex workflows between multiple agents This agent allows defining and executing complex workflows between multiple agents
using LangGraph for orchestration. using LangGraph for orchestration.
IMPORTANT: Workflow agents are orchestrators and should NOT have a model configured.
They delegate to sub-agents that have their own models.
""" """
# Field declarations for Pydantic # Field declarations for Pydantic
@ -89,6 +96,21 @@ class WorkflowAgent(BaseAgent):
sub_agents: List of sub-agents to be executed after the workflow agent sub_agents: List of sub-agents to be executed after the workflow agent
db: Session db: Session
""" """
# Workflow agents não devem ter modelos
if 'model' in kwargs:
logger.warning(f"Removing model from workflow agent {name}. Workflow agents should not have models.")
del kwargs['model']
if not flow_json:
raise ValueError(f"Workflow agent {name} requires flow_json configuration")
if not isinstance(flow_json, dict):
raise ValueError(f"Workflow agent {name} flow_json must be a dictionary")
if not flow_json.get('nodes'):
raise ValueError(f"Workflow agent {name} flow_json must contain nodes")
# Initialize base class # Initialize base class
super().__init__( super().__init__(
name=name, name=name,
@ -98,9 +120,13 @@ class WorkflowAgent(BaseAgent):
db=db, db=db,
**kwargs, **kwargs,
) )
if hasattr(self, 'model'):
logger.warning(f"Workflow agent {name} had a model attribute. Removing it.")
delattr(self, 'model')
print( logger.info(
f"Workflow agent initialized with {len(flow_json.get('nodes', []))} nodes" f"Workflow agent '{name}' initialized with {len(flow_json.get('nodes', []))} nodes"
) )
async def _create_node_functions(self, ctx: InvocationContext): async def _create_node_functions(self, ctx: InvocationContext):
@ -112,11 +138,12 @@ class WorkflowAgent(BaseAgent):
node_id: str, node_id: str,
node_data: Dict[str, Any], node_data: Dict[str, Any],
) -> AsyncGenerator[State, None]: ) -> AsyncGenerator[State, None]:
print("\n🏁 INITIAL NODE") logger.info(f"🏁 INITIAL NODE: {node_id}")
content = state.get("content", []) content = state.get("content", [])
if not content: if not content:
logger.warning("No content found in initial state")
content = [ content = [
Event( Event(
author=f"workflow-node:{node_id}", author=f"workflow-node:{node_id}",
@ -128,9 +155,11 @@ class WorkflowAgent(BaseAgent):
"status": "error", "status": "error",
"node_outputs": {}, "node_outputs": {},
"cycle_count": 0, "cycle_count": 0,
"conversation_history": ctx.session.events, "conversation_history": ctx.session.events if ctx.session else [],
"session_id": state.get("session_id", ""),
} }
return return
session_id = state.get("session_id", "") session_id = state.get("session_id", "")
# Store specific results for this node # Store specific results for this node
@ -149,7 +178,7 @@ class WorkflowAgent(BaseAgent):
"node_outputs": node_outputs, "node_outputs": node_outputs,
"cycle_count": 0, "cycle_count": 0,
"session_id": session_id, "session_id": session_id,
"conversation_history": ctx.session.events, "conversation_history": ctx.session.events if ctx.session else [],
} }
# Generic function for agent nodes # Generic function for agent nodes
@ -163,7 +192,7 @@ class WorkflowAgent(BaseAgent):
# Increment cycle counter # Increment cycle counter
cycle_count = state.get("cycle_count", 0) + 1 cycle_count = state.get("cycle_count", 0) + 1
print(f"\n👤 AGENT: {agent_name} (Cycle {cycle_count})") logger.info(f"👤 AGENT: {agent_name} (Cycle {cycle_count})")
content = state.get("content", []) content = state.get("content", [])
session_id = state.get("session_id", "") session_id = state.get("session_id", "")
@ -171,14 +200,13 @@ class WorkflowAgent(BaseAgent):
# Get conversation history # Get conversation history
conversation_history = state.get("conversation_history", []) conversation_history = state.get("conversation_history", [])
agent = get_agent(self.db, agent_id) if not agent_id:
logger.error(f"Agent node {node_id} does not have a valid agent_id")
if not agent:
yield { yield {
"content": [ "content": [
Event( Event(
author=f"workflow-node:{node_id}", author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text="Agent not found")]), content=Content(parts=[Part(text="Agent ID not configured")]),
) )
], ],
"session_id": session_id, "session_id": session_id,
@ -189,44 +217,84 @@ class WorkflowAgent(BaseAgent):
} }
return return
# Import moved to inside the function to avoid circular import agent = get_agent(self.db, agent_id)
from src.services.adk.agent_builder import AgentBuilder
agent_builder = AgentBuilder(self.db) if not agent:
root_agent, exit_stack = await agent_builder.build_agent(agent) logger.error(f"Agent not found for ID: {agent_id}")
yield {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text=f"Agent not found: {agent_id}")]),
)
],
"session_id": session_id,
"status": "error",
"node_outputs": {},
"cycle_count": cycle_count,
"conversation_history": conversation_history,
}
return
new_content = [] try:
async for event in root_agent.run_async(ctx): # Import moved to inside the function to avoid circular import
conversation_history.append(event) from src.services.adk.agent_builder import AgentBuilder
modified_event = Event(
author=f"workflow-node:{node_id}", content=event.content
)
new_content.append(modified_event)
agent_builder = AgentBuilder(self.db)
root_agent, exit_stack = await agent_builder.build_agent(agent)
print(f"New content: {new_content}") new_content = []
async for event in root_agent.run_async(ctx):
conversation_history.append(event)
modified_event = Event(
author=f"workflow-node:{node_id}", content=event.content
)
new_content.append(modified_event)
node_outputs = state.get("node_outputs", {}) logger.debug(f"Agent {agent_name} generated {len(new_content)} events")
node_outputs[node_id] = {
"processed_by": agent_name,
"agent_content": new_content,
"cycle": cycle_count,
}
content = content + new_content node_outputs = state.get("node_outputs", {})
node_outputs[node_id] = {
"processed_by": agent_name,
"agent_id": agent_id,
"agent_content": new_content,
"cycle": cycle_count,
"processed_at": datetime.now().isoformat(),
}
yield { content = content + new_content
"content": content,
"status": "processed_by_agent",
"node_outputs": node_outputs,
"cycle_count": cycle_count,
"conversation_history": conversation_history,
"session_id": session_id,
}
if exit_stack: yield {
await exit_stack.aclose() "content": content,
"status": "processed_by_agent",
"node_outputs": node_outputs,
"cycle_count": cycle_count,
"conversation_history": conversation_history,
"session_id": session_id,
}
if exit_stack:
try:
await exit_stack.aclose()
except Exception as e:
logger.warning(f"Error closing exit stack for agent {agent_name}: {e}")
except Exception as e:
logger.error(f"Error executing agent {agent_name}: {str(e)}")
yield {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text=f"Error executing agent: {str(e)}")]),
)
],
"session_id": session_id,
"status": "agent_error",
"node_outputs": state.get("node_outputs", {}),
"cycle_count": cycle_count,
"conversation_history": conversation_history,
}
# Function for condition nodes # Function for condition nodes
async def condition_node_function( async def condition_node_function(
@ -236,7 +304,7 @@ class WorkflowAgent(BaseAgent):
conditions = node_data.get("conditions", []) conditions = node_data.get("conditions", [])
cycle_count = state.get("cycle_count", 0) cycle_count = state.get("cycle_count", 0)
print(f"\n🔄 CONDITION: {label} (Cycle {cycle_count})") logger.info(f"🔄 CONDITION: {label} (Cycle {cycle_count})")
content = state.get("content", []) content = state.get("content", [])
conversation_history = state.get("conversation_history", []) conversation_history = state.get("conversation_history", [])
@ -245,16 +313,17 @@ class WorkflowAgent(BaseAgent):
if content and len(content) > 0: if content and len(content) > 0:
for event in reversed(content): for event in reversed(content):
if ( if (
event.author != "agent" hasattr(event, 'author') and
or not hasattr(event.content, "parts") event.author != "user" and
or not event.content.parts hasattr(event, 'content') and
hasattr(event.content, "parts") and
event.content.parts
): ):
latest_event = event latest_event = event
break break
if latest_event: if latest_event:
print( logger.debug(f"Evaluating condition for latest event from: {latest_event.author}")
f"Evaluating condition only for the most recent event: '{latest_event}'"
)
# Use only the most recent event for condition evaluation # Use only the most recent event for condition evaluation
evaluation_state = state.copy() evaluation_state = state.copy()
@ -273,25 +342,24 @@ class WorkflowAgent(BaseAgent):
operator = condition_data.get("operator") operator = condition_data.get("operator")
expected_value = condition_data.get("value") expected_value = condition_data.get("value")
print( logger.debug(
f" Checking if {field} {operator} '{expected_value}' (current value: '{evaluation_state.get(field, '')}')" f"Checking condition: {field} {operator} '{expected_value}'"
) )
if self._evaluate_condition(condition, evaluation_state): if self._evaluate_condition(condition, evaluation_state):
conditions_met.append(condition_id) conditions_met.append(condition_id)
condition_details.append( condition_details.append(
f"{field} {operator} '{expected_value}'" f"{field} {operator} '{expected_value}'"
) )
print(f" ✅ Condition {condition_id} met!") logger.info(f"✅ Condition {condition_id} met!")
else: else:
condition_details.append( condition_details.append(
f"{field} {operator} '{expected_value}'" f"{field} {operator} '{expected_value}'"
) )
# Check if the cycle reached the limit (extra security) max_cycles = 10 # Poderia vir da configuração
if cycle_count >= 10: if cycle_count >= max_cycles:
print( logger.warning(f"Cycle limit reached ({cycle_count}). Forcing termination.")
f"⚠️ ATTENTION: Cycle limit reached ({cycle_count}). Forcing termination."
)
condition_content = [ condition_content = [
Event( Event(
@ -314,10 +382,10 @@ class WorkflowAgent(BaseAgent):
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})
node_outputs[node_id] = { node_outputs[node_id] = {
"condition_evaluated": label, "condition_evaluated": label,
"content_evaluated": content,
"conditions_met": conditions_met, "conditions_met": conditions_met,
"condition_details": condition_details, "condition_details": condition_details,
"cycle": cycle_count, "cycle": cycle_count,
"evaluated_at": datetime.now().isoformat(),
} }
# Prepare a more descriptive message about the conditions # Prepare a more descriptive message about the conditions
@ -334,7 +402,8 @@ class WorkflowAgent(BaseAgent):
) )
] ]
), ),
) ] )
]
content = content + condition_content content = content + condition_content
yield { yield {
@ -353,7 +422,7 @@ class WorkflowAgent(BaseAgent):
message_type = message_data.get("type", "text") message_type = message_data.get("type", "text")
message_content = message_data.get("content", "") message_content = message_data.get("content", "")
print(f"\n💬 MESSAGE-NODE: {message_content}") logger.info(f"💬 MESSAGE-NODE: {message_content}")
content = state.get("content", []) content = state.get("content", [])
session_id = state.get("session_id", "") session_id = state.get("session_id", "")
@ -371,6 +440,8 @@ class WorkflowAgent(BaseAgent):
node_outputs[node_id] = { node_outputs[node_id] = {
"message_type": message_type, "message_type": message_type,
"message_content": message_content, "message_content": message_content,
"label": label,
"processed_at": datetime.now().isoformat(),
} }
yield { yield {
@ -378,7 +449,8 @@ class WorkflowAgent(BaseAgent):
"status": "message_added", "status": "message_added",
"node_outputs": node_outputs, "node_outputs": node_outputs,
"cycle_count": state.get("cycle_count", 0), "cycle_count": state.get("cycle_count", 0),
"conversation_history": conversation_history, "session_id": session_id, "conversation_history": conversation_history,
"session_id": session_id,
} }
async def delay_node_function( async def delay_node_function(
@ -389,6 +461,10 @@ class WorkflowAgent(BaseAgent):
delay_unit = delay_data.get("unit", "seconds") delay_unit = delay_data.get("unit", "seconds")
delay_description = delay_data.get("description", "") delay_description = delay_data.get("description", "")
if delay_value <= 0:
logger.warning(f"Invalid delay value: {delay_value}. Using 1 second.")
delay_value = 1
# Convert to seconds based on unit # Convert to seconds based on unit
delay_seconds = delay_value delay_seconds = delay_value
if delay_unit == "minutes": if delay_unit == "minutes":
@ -397,7 +473,7 @@ class WorkflowAgent(BaseAgent):
delay_seconds = delay_value * 3600 delay_seconds = delay_value * 3600
label = node_data.get("label", "delay_node") label = node_data.get("label", "delay_node")
print(f"\n⏱️ DELAY-NODE: {delay_value} {delay_unit} - {delay_description}") logger.info(f"⏱️ DELAY-NODE: {delay_value} {delay_unit} ({delay_seconds}s) - {delay_description}")
content = state.get("content", []) content = state.get("content", [])
session_id = state.get("session_id", "") session_id = state.get("session_id", "")
@ -409,13 +485,17 @@ class WorkflowAgent(BaseAgent):
"delay_value": delay_value, "delay_value": delay_value,
"delay_unit": delay_unit, "delay_unit": delay_unit,
"delay_seconds": delay_seconds, "delay_seconds": delay_seconds,
"delay_description": delay_description,
"delay_start_time": datetime.now().isoformat(), "delay_start_time": datetime.now().isoformat(),
} }
# Actually perform the delay # Actually perform the delay
import asyncio import asyncio
await asyncio.sleep(delay_seconds) try:
await asyncio.sleep(delay_seconds)
except asyncio.CancelledError:
logger.warning(f"Delay in node {node_id} was cancelled")
# Continue execution even if delay was cancelled
# Update node outputs with completion information # Update node outputs with completion information
node_outputs[node_id]["delay_end_time"] = datetime.now().isoformat() node_outputs[node_id]["delay_end_time"] = datetime.now().isoformat()
@ -424,7 +504,8 @@ class WorkflowAgent(BaseAgent):
yield { yield {
"content": content, "content": content,
"status": "delay_completed", "status": "delay_completed",
"node_outputs": node_outputs, "cycle_count": state.get("cycle_count", 0), "node_outputs": node_outputs,
"cycle_count": state.get("cycle_count", 0),
"conversation_history": conversation_history, "conversation_history": conversation_history,
"session_id": session_id, "session_id": session_id,
} }
@ -452,7 +533,7 @@ class WorkflowAgent(BaseAgent):
result = self._process_condition(operator, actual_value, expected_value) result = self._process_condition(operator, actual_value, expected_value)
print(f" Check '{operator}': {result}") logger.debug(f"Condition check '{operator}': {result}")
return result return result
return False return False
@ -488,7 +569,7 @@ class WorkflowAgent(BaseAgent):
if extracted_texts: if extracted_texts:
joined_text = " ".join(extracted_texts) joined_text = " ".join(extracted_texts)
print(f" Extracted text from events: '{joined_text[:100]}...'") logger.debug(f"Extracted text from events: '{joined_text[:100]}...'")
return joined_text return joined_text
return "" return ""
@ -524,6 +605,7 @@ class WorkflowAgent(BaseAgent):
elif operator in ["matches", "not_matches"]: elif operator in ["matches", "not_matches"]:
return self._check_regex(operator, actual_str, expected_str) return self._check_regex(operator, actual_str, expected_str)
logger.warning(f"Unknown operator: {operator}")
return False return False
def _check_definition(self, operator, actual_value): def _check_definition(self, operator, actual_value):
@ -563,8 +645,8 @@ class WorkflowAgent(BaseAgent):
else: # less_than_or_equal else: # less_than_or_equal
return actual_num <= expected_num return actual_num <= expected_num
except (ValueError, TypeError): except (ValueError, TypeError):
print( logger.warning(
f" Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'" f"Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
) )
return False return False
@ -579,7 +661,7 @@ class WorkflowAgent(BaseAgent):
else: # not_matches else: # not_matches
return not bool(pattern.search(actual_str)) return not bool(pattern.search(actual_str))
except re.error: except re.error:
print(f" Error in regular expression: '{expected_str}'") logger.warning(f"Error in regular expression: '{expected_str}'")
return ( return (
operator == "not_matches" operator == "not_matches"
) # Return True for not_matches, False for matches ) # Return True for not_matches, False for matches
@ -589,8 +671,8 @@ class WorkflowAgent(BaseAgent):
expected_lower = expected_str.lower() expected_lower = expected_str.lower()
actual_lower = actual_str.lower() actual_lower = actual_str.lower()
print( logger.debug(
f" Comparison '{operator}' without case distinction: '{expected_lower}' in '{actual_lower[:100]}...'" f"Comparison '{operator}' case insensitive: '{expected_lower}' in '{actual_lower[:100]}...'"
) )
if operator == "contains": if operator == "contains":
@ -627,14 +709,13 @@ class WorkflowAgent(BaseAgent):
# Routing function for each specific node # Routing function for each specific node
def create_router_for_node(node_id: str): def create_router_for_node(node_id: str):
def router(state: State) -> str: def router(state: State) -> str:
print(f"Routing from node: {node_id}") logger.debug(f"Routing from node: {node_id}")
# Check if the cycle limit has been reached # Check if the cycle limit has been reached
cycle_count = state.get("cycle_count", 0) cycle_count = state.get("cycle_count", 0)
if cycle_count >= 10: max_cycles = 10 # Configurável
print( if cycle_count >= max_cycles:
f"⚠️ Cycle limit ({cycle_count}) reached. Finalizing the flow." logger.warning(f"Cycle limit ({cycle_count}) reached. Finalizing the flow.")
)
return END return END
# If it's a condition node, evaluate the conditions # If it's a condition node, evaluate the conditions
@ -648,32 +729,29 @@ class WorkflowAgent(BaseAgent):
if conditions_met: if conditions_met:
any_condition_met = True any_condition_met = True
condition_id = conditions_met[0] condition_id = conditions_met[0]
print( logger.debug(f"Using stored condition result: Condition {condition_id} met.")
f"Using stored condition evaluation result: Condition {condition_id} met."
)
if ( if (
node_id in edges_map node_id in edges_map
and condition_id in edges_map[node_id] and condition_id in edges_map[node_id]
): ):
return edges_map[node_id][condition_id] return edges_map[node_id][condition_id]
else: else:
print( logger.debug("Using stored condition result: No conditions met.")
"Using stored condition evaluation result: No conditions met."
)
else: else:
# Evaluate conditions
for condition in conditions: for condition in conditions:
condition_id = condition.get("id") condition_id = condition.get("id")
# Get latest event for evaluation, ignoring condition node informational events # Get latest event for evaluation, ignoring condition node informational events
content = state.get("content", []) content = state.get("content", [])
# Filter out events generated by condition nodes or informational messages # Filter out events generated by condition nodes or that contain evaluation results
filtered_content = [] filtered_content = []
for event in content: for event in content:
# Ignore events from condition nodes or that contain evaluation results # Ignore events from condition nodes or that contain evaluation results
if not hasattr(event, "author") or not ( if not hasattr(event, "author") or not (
event.author.startswith("Condition") event.author.startswith("workflow-node:") and
or "Condition evaluated:" in str(event) "Condition evaluated:" in str(event)
): ):
filtered_content.append(event) filtered_content.append(event)
@ -687,9 +765,7 @@ class WorkflowAgent(BaseAgent):
if is_condition_met: if is_condition_met:
any_condition_met = True any_condition_met = True
print( logger.debug(f"Condition {condition_id} met. Moving to next node.")
f"Condition {condition_id} met. Moving to the next node."
)
# Find the connection that uses this condition_id as a handle # Find the connection that uses this condition_id as a handle
if ( if (
@ -698,9 +774,7 @@ class WorkflowAgent(BaseAgent):
): ):
return edges_map[node_id][condition_id] return edges_map[node_id][condition_id]
else: else:
print( logger.debug(f"Condition {condition_id} not met.")
f"Condition {condition_id} not met. Continuing evaluation or using default path."
)
# If no condition is met, use the bottom-handle if available # If no condition is met, use the bottom-handle if available
if not any_condition_met: if not any_condition_met:
@ -708,14 +782,10 @@ class WorkflowAgent(BaseAgent):
node_id in edges_map node_id in edges_map
and "bottom-handle" in edges_map[node_id] and "bottom-handle" in edges_map[node_id]
): ):
print( logger.debug("No condition met. Using default path (bottom-handle).")
"No condition met. Using default path (bottom-handle)."
)
return edges_map[node_id]["bottom-handle"] return edges_map[node_id]["bottom-handle"]
else: else:
print( logger.debug("No condition met and no default path. Closing the flow.")
"No condition met and no default path. Closing the flow."
)
return END return END
# For regular nodes, simply follow the first available connection # For regular nodes, simply follow the first available connection
@ -731,7 +801,7 @@ class WorkflowAgent(BaseAgent):
return edges_map[node_id][first_handle] return edges_map[node_id][first_handle]
# If there is no output connection, close the flow # If there is no output connection, close the flow
print(f"No output connection from node {node_id}. Closing the flow.") logger.debug(f"No output connection from node {node_id}. Closing the flow.")
return END return END
return router return router
@ -745,6 +815,9 @@ class WorkflowAgent(BaseAgent):
# Extract nodes from the flow # Extract nodes from the flow
nodes = flow_data.get("nodes", []) nodes = flow_data.get("nodes", [])
if not nodes:
raise ValueError("Flow data must contain at least one node")
# Initialize StateGraph # Initialize StateGraph
graph_builder = StateGraph(State) graph_builder = StateGraph(State)
@ -754,34 +827,60 @@ class WorkflowAgent(BaseAgent):
# Dictionary to store specific functions for each node # Dictionary to store specific functions for each node
node_specific_functions = {} node_specific_functions = {}
valid_node_types = set(node_functions.keys())
# Add nodes to the graph # Add nodes to the graph
for node in nodes: for node in nodes:
node_id = node.get("id") node_id = node.get("id")
node_type = node.get("type") node_type = node.get("type")
node_data = node.get("data", {}) node_data = node.get("data", {})
if node_type in node_functions: if not node_id:
# Create a specific function for this node logger.warning(f"Skipping node without ID: {node}")
def create_node_function(node_type, node_id, node_data): continue
async def node_function(state):
# Consume the asynchronous generator and return the last result if node_type not in valid_node_types:
result = None logger.warning(f"Unknown node type '{node_type}' for node {node_id}. Skipping.")
continue
# Create a specific function for this node
def create_node_function(node_type, node_id, node_data):
async def node_function(state):
# Consume the asynchronous generator and return the last result
result = None
try:
async for item in node_functions[node_type]( async for item in node_functions[node_type](
state, node_id, node_data state, node_id, node_data
): ):
result = item result = item
return result return result
except Exception as e:
logger.error(f"Error in node {node_id} ({node_type}): {str(e)}")
# Return error state
return {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text=f"Node error: {str(e)}")]),
)
],
"status": "node_error",
"node_outputs": state.get("node_outputs", {}),
"cycle_count": state.get("cycle_count", 0),
"conversation_history": state.get("conversation_history", []),
"session_id": state.get("session_id", ""),
}
return node_function return node_function
# Add specific function to the dictionary # Add specific function to the dictionary
node_specific_functions[node_id] = create_node_function( node_specific_functions[node_id] = create_node_function(
node_type, node_id, node_data node_type, node_id, node_data
) )
# Add node to the graph # Add node to the graph
print(f"Adding node {node_id} of type {node_type}") logger.debug(f"Adding node {node_id} of type {node_type}")
graph_builder.add_node(node_id, node_specific_functions[node_id]) graph_builder.add_node(node_id, node_specific_functions[node_id])
# Create function to generate specific routers # Create function to generate specific routers
create_router = self._create_flow_router(flow_data) create_router = self._create_flow_router(flow_data)
@ -808,8 +907,8 @@ class WorkflowAgent(BaseAgent):
node_router = create_router(node_id) node_router = create_router(node_id)
# Add conditional connections # Add conditional connections
print(f"Adding conditional connections for node {node_id}") logger.debug(f"Adding conditional connections for node {node_id}")
print(f"Possible destinations: {edge_destinations}") logger.debug(f"Possible destinations: {list(edge_destinations.keys())}")
graph_builder.add_conditional_edges( graph_builder.add_conditional_edges(
node_id, node_router, edge_destinations node_id, node_router, edge_destinations
@ -825,35 +924,56 @@ class WorkflowAgent(BaseAgent):
# If there is no start-node, use the first node found # If there is no start-node, use the first node found
if not entry_point and nodes: if not entry_point and nodes:
entry_point = nodes[0].get("id") entry_point = nodes[0].get("id")
logger.warning(f"No start-node found, using first node as entry point: {entry_point}")
# Define the entry point # Define the entry point
if entry_point: if entry_point:
print(f"Defining entry point: {entry_point}") logger.info(f"Setting entry point: {entry_point}")
graph_builder.set_entry_point(entry_point) graph_builder.set_entry_point(entry_point)
else:
raise ValueError("No valid entry point found for workflow")
# Compile the graph # Compile the graph
return graph_builder.compile() try:
compiled_graph = graph_builder.compile()
logger.info("Workflow graph compiled successfully")
return compiled_graph
except Exception as e:
logger.error(f"Error compiling workflow graph: {str(e)}")
raise ValueError(f"Error compiling workflow graph: {str(e)}")
async def _run_async_impl( async def _run_async_impl(
self, ctx: InvocationContext self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]: ) -> AsyncGenerator[Event, None]:
"""Implementation of the workflow agent executing the defined workflow and returning results.""" """Implementation of the workflow agent executing the defined workflow and returning results."""
if hasattr(self, 'model') and self.model:
logger.error(f"Workflow agent {self.name} should not have a model configured")
raise ValueError(f"Workflow agent {self.name} is an orchestrator and should not have a model. Models should be configured on sub-agents.")
try: try:
logger.info(f"Starting workflow execution for agent: {self.name}")
logger.debug(f"Context session ID: {ctx.session.id if ctx.session else 'No session'}")
user_message = await self._extract_user_message(ctx) user_message = await self._extract_user_message(ctx)
session_id = self._get_session_id(ctx) session_id = self._get_session_id(ctx)
if not self.flow_json:
raise ValueError("Workflow agent has no flow_json configured")
graph = await self._create_graph(ctx, self.flow_json) graph = await self._create_graph(ctx, self.flow_json)
initial_state = await self._prepare_initial_state( initial_state = await self._prepare_initial_state(
ctx, user_message, session_id ctx, user_message, session_id
) )
print("\n🚀 Starting workflow execution:") logger.info(f"🚀 Starting workflow execution with initial message: {user_message[:100]}...")
print(f"Initial content: {user_message[:100]}...")
# Iterar sobre o AsyncGenerator em vez de usar await # Iterar sobre o AsyncGenerator em vez de usar await
async for event in self._execute_workflow(ctx, graph, initial_state): async for event in self._execute_workflow(ctx, graph, initial_state):
yield event yield event
except Exception as e: except Exception as e:
logger.error(f"Error in workflow execution: {str(e)}", exc_info=True)
yield await self._handle_workflow_error(e) yield await self._handle_workflow_error(e)
async def _extract_user_message(self, ctx: InvocationContext) -> str: async def _extract_user_message(self, ctx: InvocationContext) -> str:
@ -861,24 +981,36 @@ class WorkflowAgent(BaseAgent):
# Try to find message in session events # Try to find message in session events
if ctx.session and hasattr(ctx.session, "events") and ctx.session.events: if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
for event in reversed(ctx.session.events): for event in reversed(ctx.session.events):
if event.author == "user" and event.content and event.content.parts: if (
print("Message found in session events") hasattr(event, 'author') and
event.author == "user" and
hasattr(event, 'content') and
event.content and
hasattr(event.content, 'parts') and
event.content.parts
):
logger.debug("User message found in session events")
return event.content.parts[0].text return event.content.parts[0].text
# Try to find message in session state # Try to find message in session state
if ctx.session and ctx.session.state: if ctx.session and hasattr(ctx.session, 'state') and ctx.session.state:
if "user_message" in ctx.session.state: if "user_message" in ctx.session.state:
return ctx.session.state["user_message"] return ctx.session.state["user_message"]
elif "message" in ctx.session.state: elif "message" in ctx.session.state:
return ctx.session.state["message"] return ctx.session.state["message"]
return "" logger.warning("No user message found in context")
return "No user message provided"
def _get_session_id(self, ctx: InvocationContext) -> str: def _get_session_id(self, ctx: InvocationContext) -> str:
"""Gets or generates a session ID.""" """Gets or generates a session ID."""
if ctx.session and hasattr(ctx.session, "id"): if ctx.session and hasattr(ctx.session, "id") and ctx.session.id:
return str(ctx.session.id) return str(ctx.session.id)
return str(uuid.uuid4())
# Generate a new session ID
new_session_id = str(uuid.uuid4())
logger.debug(f"Generated new session ID: {new_session_id}")
return new_session_id
async def _prepare_initial_state( async def _prepare_initial_state(
self, ctx: InvocationContext, user_message: str, session_id: str self, ctx: InvocationContext, user_message: str, session_id: str
@ -889,9 +1021,13 @@ class WorkflowAgent(BaseAgent):
content=Content(parts=[Part(text=user_message)]), content=Content(parts=[Part(text=user_message)]),
) )
conversation_history = ctx.session.events or [user_event] conversation_history = []
if ctx.session and hasattr(ctx.session, 'events') and ctx.session.events:
conversation_history = ctx.session.events.copy()
else:
conversation_history = [user_event]
return State( initial_state = State(
content=[user_event], content=[user_event],
status="started", status="started",
session_id=session_id, session_id=session_id,
@ -899,34 +1035,61 @@ class WorkflowAgent(BaseAgent):
node_outputs={}, node_outputs={},
conversation_history=conversation_history, conversation_history=conversation_history,
) )
logger.debug(f"Initial state prepared with {len(conversation_history)} history events")
return initial_state
async def _execute_workflow( async def _execute_workflow(
self, ctx: InvocationContext, graph: StateGraph, initial_state: State self, ctx: InvocationContext, graph: StateGraph, initial_state: State
) -> AsyncGenerator[Event, None]: ) -> AsyncGenerator[Event, None]:
"""Executes the workflow graph and yields events.""" """Executes the workflow graph and yields events."""
sent_events = 0 sent_events = 0
total_iterations = 0
max_iterations = 100
async for state in graph.astream(initial_state, {"recursion_limit": 100}): try:
for node_state in state.values(): async for state in graph.astream(initial_state, {"recursion_limit": max_iterations}):
content = node_state.get("content", []) total_iterations += 1
for event in content[sent_events:]:
if event.author != "user": if total_iterations > max_iterations:
logger.warning(f"Maximum iterations ({max_iterations}) reached, stopping workflow")
break
for node_state in state.values():
content = node_state.get("content", [])
# Yield new events that haven't been sent yet
for event in content[sent_events:]:
if hasattr(event, 'author') and event.author != "user":
yield event
sent_events = len(content)
logger.info(f"Workflow completed after {total_iterations} iterations")
except Exception as e:
logger.error(f"Error during workflow execution: {str(e)}")
yield await self._handle_workflow_error(e)
if self.sub_agents:
logger.info(f"Executing {len(self.sub_agents)} sub-agents")
for sub_agent in self.sub_agents:
try:
async for event in sub_agent.run_async(ctx):
yield event yield event
sent_events = len(content) except Exception as e:
logger.error(f"Error executing sub-agent {sub_agent.name}: {str(e)}")
# Execute sub-agents if any yield await self._handle_workflow_error(e)
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
async def _handle_workflow_error(self, error: Exception) -> Event: async def _handle_workflow_error(self, error: Exception) -> Event:
"""Creates an error event for workflow execution errors.""" """Creates an error event for workflow execution errors."""
error_msg = f"Error executing the workflow agent: {str(error)}" error_msg = f"Error executing workflow agent '{self.name}': {str(error)}"
print(error_msg) logger.error(error_msg)
return Event( return Event(
author=f"workflow-error:{self.name}", author=f"workflow-error:{self.name}",
content=Content( content=Content(
role="agent", role="agent",
parts=[Part(text=error_msg)], parts=[Part(text=error_msg)],
), ),
) )