This commit is contained in:
Anderson Lemes 2025-06-05 23:23:05 +00:00 committed by GitHub
commit 61a7c082c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 830 additions and 319 deletions

View File

@ -158,7 +158,7 @@ async def get_agent_messages(
processed_events = []
for event in events:
event_dict = event.dict()
event_dict = event.model_dump()
def process_dict(d):
if isinstance(d, dict):

View File

@ -27,7 +27,7 @@
"""
from pydantic import BaseModel, Field, validator, UUID4, ConfigDict
from pydantic import BaseModel, Field, field_validator, model_validator, UUID4, ConfigDict
from typing import Optional, Dict, Any, List
from datetime import datetime
from uuid import UUID
@ -40,7 +40,8 @@ class ClientBase(BaseModel):
name: str
email: Optional[str] = None
@validator("email")
@field_validator("email")
@classmethod
def validate_email(cls, v):
if v is None:
return v
@ -58,8 +59,7 @@ class Client(ClientBase):
id: UUID
created_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ApiKeyBase(BaseModel):
@ -101,7 +101,7 @@ class AgentBase(BaseModel):
description="Agent type (llm, sequential, parallel, loop, a2a, workflow, task)",
)
model: Optional[str] = Field(
None, description="Agent model (required only for llm type)"
None, description="LLM model identifier (required for LLM agents only)"
)
api_key_id: Optional[UUID4] = Field(
None, description="Reference to a stored API Key ID"
@ -115,8 +115,13 @@ class AgentBase(BaseModel):
)
config: Any = Field(None, description="Agent configuration based on type")
@validator("name")
def validate_name(cls, v, values):
@field_validator("name")
@classmethod
def validate_name(cls, v, info):
# Get values from validation context
values = info.data if hasattr(info, 'data') else {}
# A2A agents can have optional names
if values.get("type") == "a2a":
return v
@ -127,107 +132,246 @@ class AgentBase(BaseModel):
raise ValueError("Agent name cannot contain spaces or special characters")
return v
@validator("type")
@field_validator("type")
@classmethod
def validate_type(cls, v):
if v not in [
valid_types = [
"llm",
"sequential",
"parallel",
"loop",
"a2a",
"workflow",
"task",
]:
"sequential",
"parallel",
"loop",
"a2a",
"workflow",
"task"
]
if v not in valid_types:
raise ValueError(
"Invalid agent type. Must be: llm, sequential, parallel, loop, a2a, workflow or task"
f"Invalid agent type '{v}'. Must be one of: {', '.join(valid_types)}"
)
return v
@validator("agent_card_url")
def validate_agent_card_url(cls, v, values):
if "type" in values and values["type"] == "a2a":
@field_validator("agent_card_url")
@classmethod
def validate_agent_card_url(cls, v, info):
values = info.data if hasattr(info, 'data') else {}
if values.get("type") == "a2a":
if not v:
raise ValueError("agent_card_url is required for a2a type agents")
if not v.endswith("/.well-known/agent.json"):
raise ValueError("agent_card_url must end with /.well-known/agent.json")
return v
@validator("model")
def validate_model(cls, v, values):
if "type" in values and values["type"] == "llm" and not v:
raise ValueError("Model is required for llm type agents")
@field_validator("model")
@classmethod
def validate_model(cls, v, info):
values = info.data if hasattr(info, 'data') else {}
agent_type = values.get("type")
if agent_type == "llm":
# Para agentes LLM, o modelo é obrigatório e não pode ser vazio
if not v or (isinstance(v, str) and v.strip() == ""):
raise ValueError(
"LLM agents require a valid model configuration. "
"Please specify a model identifier (e.g., 'gpt-4', 'claude-3-sonnet', 'gemini-pro')"
)
# Verificar se o modelo tem um formato válido
if isinstance(v, str) and len(v.strip()) < 3:
raise ValueError("Model identifier must be at least 3 characters long")
elif agent_type in ["workflow", "task", "sequential", "parallel", "loop"]:
# Para estes tipos, não devem ter modelo
if v and (isinstance(v, str) and v.strip()):
# Avisar mas permitir (será removido durante a criação)
import logging
logger = logging.getLogger(__name__)
logger.warning(f"{agent_type} agents don't need model configuration. Model will be ignored.")
return v
@validator("api_key_id")
def validate_api_key_id(cls, v, values):
@field_validator("api_key_id")
@classmethod
def validate_api_key_id(cls, v, info):
values = info.data if hasattr(info, 'data') else {}
agent_type = values.get("type")
# API key é obrigatório para agentes LLM (a menos que esteja na config)
if agent_type == "llm" and not v:
# Verificar se tem API key na config
config = values.get("config", {})
if not config or not config.get("api_key"):
# Não falhar aqui, deixar a validação para o momento da criação
pass
return v
@validator("config")
def validate_config(cls, v, values):
if "type" in values and values["type"] == "a2a":
@field_validator("config")
@classmethod
def validate_config(cls, v, info):
values = info.data if hasattr(info, 'data') else {}
agent_type = values.get("type")
if not agent_type:
return v
# A2A agents têm config opcional
if agent_type == "a2a":
return v or {}
if "type" not in values:
# Workflow agents têm config específico para workflow
if agent_type == "workflow":
if v and isinstance(v, dict):
if not v.get("workflow"):
raise ValueError("Workflow agents must have 'workflow' configuration")
return v
# For workflow agents, we do not perform any validation
if "type" in values and values["type"] == "workflow":
return v
if not v and values.get("type") != "a2a":
# Config é obrigatório para outros tipos (exceto a2a)
if not v and agent_type not in ["a2a"]:
raise ValueError(
f"Configuration is required for {values.get('type')} agent type"
f"Configuration is required for {agent_type} agent type"
)
if values["type"] == "llm":
if isinstance(v, dict):
try:
# Convert the dictionary to LLMConfig
v = LLMConfig(**v)
except Exception as e:
raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
elif not isinstance(v, LLMConfig):
raise ValueError("Invalid LLM configuration for agent")
elif values["type"] in ["sequential", "parallel", "loop"]:
if not isinstance(v, dict):
raise ValueError(f'Invalid configuration for agent {values["type"]}')
if "sub_agents" not in v:
raise ValueError(f'Agent {values["type"]} must have sub_agents')
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
if not v["sub_agents"]:
raise ValueError(
f'Agent {values["type"]} must have at least one sub-agent'
)
elif values["type"] == "task":
if not isinstance(v, dict):
raise ValueError(f'Invalid configuration for agent {values["type"]}')
if "tasks" not in v:
raise ValueError(f'Agent {values["type"]} must have tasks')
if not isinstance(v["tasks"], list):
raise ValueError("tasks must be a list")
if not v["tasks"]:
raise ValueError(f'Agent {values["type"]} must have at least one task')
for task in v["tasks"]:
if not isinstance(task, dict):
raise ValueError("Each task must be a dictionary")
required_fields = ["agent_id", "description", "expected_output"]
for field in required_fields:
if field not in task:
raise ValueError(f"Task missing required field: {field}")
if "sub_agents" in v and v["sub_agents"] is not None:
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
return v
# Validação específica por tipo
if agent_type == "llm":
return cls._validate_llm_config(v)
elif agent_type in ["sequential", "parallel", "loop"]:
return cls._validate_composite_config(v, agent_type)
elif agent_type == "task":
return cls._validate_task_config(v)
return v
@classmethod
def _validate_llm_config(cls, v):
"""Valida configuração para agentes LLM"""
if isinstance(v, dict):
try:
# Convert the dictionary to LLMConfig
v = LLMConfig(**v)
except Exception as e:
raise ValueError(f"Invalid LLM configuration: {str(e)}")
elif not isinstance(v, LLMConfig):
raise ValueError("Invalid LLM configuration format")
return v
@classmethod
def _validate_composite_config(cls, v, agent_type):
"""Valida configuração para agentes compostos (sequential, parallel, loop)"""
if not isinstance(v, dict):
raise ValueError(f'Configuration for {agent_type} agent must be a dictionary')
if "sub_agents" not in v:
raise ValueError(f'{agent_type} agents must have sub_agents configuration')
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
if not v["sub_agents"]:
raise ValueError(
f'{agent_type} agents must have at least one sub-agent'
)
# Validação específica para LoopAgent
if agent_type == "loop":
max_iterations = v.get("max_iterations", 5)
if not isinstance(max_iterations, int) or max_iterations <= 0:
raise ValueError("max_iterations must be a positive integer")
return v
@classmethod
def _validate_task_config(cls, v):
"""Valida configuração para agentes de task"""
if not isinstance(v, dict):
raise ValueError('Configuration for task agent must be a dictionary')
if "tasks" not in v:
raise ValueError('Task agents must have tasks configuration')
if not isinstance(v["tasks"], list):
raise ValueError("tasks must be a list")
if not v["tasks"]:
raise ValueError('Task agents must have at least one task')
# Validar cada task individualmente
for i, task in enumerate(v["tasks"]):
if not isinstance(task, dict):
raise ValueError(f"Task {i+1} must be a dictionary")
required_fields = ["agent_id", "description", "expected_output"]
for field in required_fields:
if field not in task:
raise ValueError(f"Task {i+1} missing required field: {field}")
# Verificar se os campos não estão vazios
if not task[field] or (isinstance(task[field], str) and not task[field].strip()):
raise ValueError(f"Task {i+1} field '{field}' cannot be empty")
# Validar sub_agents se presente
if "sub_agents" in v and v["sub_agents"] is not None:
if not isinstance(v["sub_agents"], list):
raise ValueError("sub_agents must be a list")
return v
@model_validator(mode='after')
def validate_agent_consistency(self):
"""Validação cruzada entre campos do agente"""
# Verificar consistência entre tipo e configurações
if self.type == "llm":
# LLM agents devem ter modelo
if not self.model or (isinstance(self.model, str) and self.model.strip() == ""):
raise ValueError("LLM agents must have a valid model")
# LLM agents devem ter API key (na config ou api_key_id)
has_api_key = bool(self.api_key_id)
if not has_api_key and self.config:
config_dict = self.config if isinstance(self.config, dict) else self.config.__dict__
has_api_key = bool(config_dict.get("api_key"))
if not has_api_key:
raise ValueError("LLM agents must have an API key configured")
elif self.type in ["workflow", "task", "sequential", "parallel", "loop"]:
# Orchestrator agents não devem ter modelo
if self.model and isinstance(self.model, str) and self.model.strip():
import logging
logger = logging.getLogger(__name__)
logger.warning(f"{self.type} agents don't need model configuration. Clearing model.")
self.model = None
elif self.type == "a2a":
# A2A agents devem ter agent_card_url
if not self.agent_card_url:
raise ValueError("A2A agents must have agent_card_url")
return self
class AgentCreate(AgentBase):
client_id: UUID
@model_validator(mode='after')
def validate_creation_requirements(self):
"""Validações específicas para criação de agentes"""
# Chamar validação da classe pai
super().validate_agent_consistency()
# Validações específicas para criação
if self.type == "llm":
# Para criação, ser mais rigoroso com modelo
if not self.model or len(self.model.strip()) < 3:
raise ValueError(
"LLM agents require a valid model identifier (minimum 3 characters). "
"Examples: 'gpt-4', 'claude-3-sonnet', 'gemini-pro'"
)
return self
class Agent(AgentBase):
id: UUID
@ -237,17 +381,17 @@ class Agent(AgentBase):
agent_card_url: Optional[str] = None
folder_id: Optional[UUID4] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
@validator("agent_card_url", pre=True)
def set_agent_card_url(cls, v, values):
@field_validator("agent_card_url", mode='before')
@classmethod
def set_agent_card_url(cls, v, info):
if v:
return v
values = info.data if hasattr(info, 'data') else {}
if "id" in values:
from os import getenv
return f"{getenv('API_URL', '')}/api/v1/a2a/{values['id']}/.well-known/agent.json"
return v
@ -262,6 +406,7 @@ class ToolConfig(BaseModel):
inputModes: List[str] = Field(default_factory=list)
outputModes: List[str] = Field(default_factory=list)
# Last edited by Arley Peter on 2025-05-17
class MCPServerBase(BaseModel):
name: str
@ -272,6 +417,29 @@ class MCPServerBase(BaseModel):
tools: Optional[List[ToolConfig]] = Field(default_factory=list)
type: str = Field(default="official")
@field_validator("name")
@classmethod
def validate_name(cls, v):
if not v or not v.strip():
raise ValueError("MCP Server name cannot be empty")
return v.strip()
@field_validator("config_type")
@classmethod
def validate_config_type(cls, v):
valid_types = ["studio", "sse"]
if v not in valid_types:
raise ValueError(f"config_type must be one of: {valid_types}")
return v
@field_validator("type")
@classmethod
def validate_type(cls, v):
valid_types = ["official", "community"]
if v not in valid_types:
raise ValueError(f"type must be one of: {valid_types}")
return v
class MCPServerCreate(MCPServerBase):
pass
@ -282,8 +450,7 @@ class MCPServer(MCPServerBase):
created_at: datetime
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ToolBase(BaseModel):
@ -292,6 +459,13 @@ class ToolBase(BaseModel):
config_json: Dict[str, Any] = Field(default_factory=dict)
environments: Dict[str, Any] = Field(default_factory=dict)
@field_validator("name")
@classmethod
def validate_name(cls, v):
if not v or not v.strip():
raise ValueError("Tool name cannot be empty")
return v.strip()
class ToolCreate(ToolBase):
pass
@ -302,14 +476,20 @@ class Tool(ToolBase):
created_at: datetime
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class AgentFolderBase(BaseModel):
name: str
description: Optional[str] = None
@field_validator("name")
@classmethod
def validate_name(cls, v):
if not v or not v.strip():
raise ValueError("Folder name cannot be empty")
return v.strip()
class AgentFolderCreate(AgentFolderBase):
client_id: UUID4
@ -326,3 +506,86 @@ class AgentFolder(AgentFolderBase):
updated_at: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class AgentTypeInfo(BaseModel):
"""Informações sobre tipos de agente válidos"""
type: str
requires_model: bool
requires_config: bool
description: str
@classmethod
def get_valid_types(cls) -> Dict[str, 'AgentTypeInfo']:
"""Retorna informações sobre todos os tipos válidos de agente"""
return {
"llm": cls(
type="llm",
requires_model=True,
requires_config=True,
description="Large Language Model agent - requires model and API key"
),
"workflow": cls(
type="workflow",
requires_model=False,
requires_config=True,
description="Workflow orchestrator agent - uses LangGraph for complex flows"
),
"task": cls(
type="task",
requires_model=False,
requires_config=True,
description="Task management agent - coordinates multiple tasks"
),
"sequential": cls(
type="sequential",
requires_model=False,
requires_config=True,
description="Sequential execution agent - runs sub-agents in order"
),
"parallel": cls(
type="parallel",
requires_model=False,
requires_config=True,
description="Parallel execution agent - runs sub-agents concurrently"
),
"loop": cls(
type="loop",
requires_model=False,
requires_config=True,
description="Loop execution agent - repeats sub-agents with conditions"
),
"a2a": cls(
type="a2a",
requires_model=False,
requires_config=False,
description="Agent-to-Agent communication - external agent integration"
)
}
class ModelValidationResult(BaseModel):
"""Resultado da validação de modelo"""
is_valid: bool
error_message: Optional[str] = None
warnings: List[str] = Field(default_factory=list)
@classmethod
def success(cls, warnings: List[str] = None) -> 'ModelValidationResult':
return cls(is_valid=True, warnings=warnings or [])
@classmethod
def failure(cls, error_message: str) -> 'ModelValidationResult':
return cls(is_valid=False, error_message=error_message)
class AgentValidationSummary(BaseModel):
"""Resumo de validação de agente"""
agent_id: Optional[UUID] = None
agent_name: Optional[str] = None
agent_type: str
is_valid: bool
model_validation: ModelValidationResult
config_validation: ModelValidationResult
general_errors: List[str] = Field(default_factory=list)
warnings: List[str] = Field(default_factory=list)

View File

@ -67,15 +67,39 @@ class AgentBuilder:
if agent_tools_ids and isinstance(agent_tools_ids, list):
for agent_tool_id in agent_tools_ids:
sub_agent = get_agent(self.db, agent_tool_id)
llm_agent, _ = await self.build_llm_agent(sub_agent)
if llm_agent:
agent_tools.append(AgentTool(agent=llm_agent))
if sub_agent:
# Verificar se o sub_agent é do tipo LLM antes de criar LlmAgent
if sub_agent.type == "llm":
llm_agent, _ = await self.build_llm_agent(sub_agent)
if llm_agent:
agent_tools.append(AgentTool(agent=llm_agent))
else:
logger.warning(f"Agent tool {agent_tool_id} is not of type 'llm', skipping")
else:
logger.warning(f"Agent tool {agent_tool_id} not found")
return agent_tools
def _validate_llm_agent_model(self, agent) -> None:
"""Validate that LLM agent has a proper model configuration."""
if not hasattr(agent, 'model') or not agent.model:
logger.error(f"LLM agent {agent.name} does not have a model configured")
raise ValueError(f"LLM agent {agent.name} requires a model configuration")
if isinstance(agent.model, str) and agent.model.strip() == "":
logger.error(f"LLM agent {agent.name} has an empty model string")
raise ValueError(f"LLM agent {agent.name} has an empty model configuration")
logger.info(f"Model validation passed for agent {agent.name}: {agent.model}")
async def _create_llm_agent(
self, agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Create an LLM agent from the agent data."""
self._validate_llm_agent_model(agent)
logger.info(f"Creating LLM agent: {agent.name} with model: {agent.model}")
# Get custom tools from the configuration
custom_tools = []
custom_tools = self.custom_tool_builder.build_tools(agent.config)
@ -110,7 +134,7 @@ class AgentBuilder:
current_day_of_week=current_day_of_week,
current_date_iso=current_date_iso,
current_time=current_time,
)
) if agent.instruction else ""
# add role on beginning of the prompt
if agent.role:
@ -170,21 +194,27 @@ class AgentBuilder:
f"Agent {agent.name} does not have a configured API key"
)
return (
LlmAgent(
if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
raise ValueError(f"Cannot create LiteLlm with empty model for agent {agent.name}")
try:
llm_agent = LlmAgent(
name=agent.name,
model=LiteLlm(model=agent.model, api_key=api_key),
instruction=formatted_prompt,
description=agent.description,
tools=all_tools,
),
mcp_exit_stack,
)
)
logger.info(f"LLM agent created successfully: {agent.name}")
return llm_agent, mcp_exit_stack
except Exception as e:
logger.error(f"Error creating LLM agent {agent.name}: {str(e)}")
raise ValueError(f"Error creating LLM agent {agent.name}: {str(e)}")
async def _get_sub_agents(
self, sub_agent_ids: List[str]
) -> List[Tuple[LlmAgent, Optional[AsyncExitStack]]]:
"""Get and create LLM sub-agents."""
) -> List[Tuple[BaseAgent, Optional[AsyncExitStack]]]:
"""Get and create sub-agents with proper type validation."""
sub_agents = []
for sub_agent_id in sub_agent_ids:
sub_agent_id_str = str(sub_agent_id)
@ -197,39 +227,50 @@ class AgentBuilder:
logger.info(f"Sub-agent found: {agent.name} (type: {agent.type})")
if agent.type == "llm":
sub_agent, exit_stack = await self._create_llm_agent(agent)
elif agent.type == "a2a":
sub_agent, exit_stack = await self.build_a2a_agent(agent)
elif agent.type == "workflow":
sub_agent, exit_stack = await self.build_workflow_agent(agent)
elif agent.type == "task":
sub_agent, exit_stack = await self.build_task_agent(agent)
elif agent.type == "sequential":
sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "parallel":
sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "loop":
sub_agent, exit_stack = await self.build_composite_agent(agent)
else:
raise ValueError(f"Invalid agent type: {agent.type}")
try:
if agent.type == "llm":
# Verificar se tem modelo antes de criar
if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
logger.error(f"LLM sub-agent {agent.name} does not have a model configured")
raise ValueError(f"LLM sub-agent {agent.name} requires a model configuration")
sub_agent, exit_stack = await self._create_llm_agent(agent)
elif agent.type == "a2a":
sub_agent, exit_stack = await self.build_a2a_agent(agent)
elif agent.type == "workflow":
# Workflow agents não precisam de modelo
sub_agent, exit_stack = await self.build_workflow_agent(agent)
elif agent.type == "task":
sub_agent, exit_stack = await self.build_task_agent(agent)
elif agent.type == "sequential":
sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "parallel":
sub_agent, exit_stack = await self.build_composite_agent(agent)
elif agent.type == "loop":
sub_agent, exit_stack = await self.build_composite_agent(agent)
else:
raise ValueError(f"Invalid agent type: {agent.type}")
sub_agents.append(sub_agent)
logger.info(f"Sub-agent added: {agent.name}")
sub_agents.append((sub_agent, exit_stack))
logger.info(f"Sub-agent added: {agent.name}")
except Exception as e:
logger.error(f"Error creating sub-agent {agent.name}: {str(e)}")
raise ValueError(f"Error creating sub-agent {agent.name}: {str(e)}")
logger.info(f"Sub-agents created: {len(sub_agents)}")
logger.info(f"Sub-agents: {str(sub_agents)}")
return sub_agents
async def build_llm_agent(
self, root_agent, enabled_tools: List[str] = []
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
"""Build an LLM agent with its sub-agents."""
logger.info("Creating LLM agent")
logger.info(f"Creating LLM agent: {root_agent.name}")
if root_agent.type != "llm":
raise ValueError(f"Expected LLM agent, got {root_agent.type}")
sub_agents = []
if root_agent.config.get("sub_agents"):
if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents")
)
@ -241,20 +282,21 @@ class AgentBuilder:
if sub_agents:
root_llm_agent.sub_agents = sub_agents
logger.info(f"LLM agent built successfully: {root_agent.name}")
return root_llm_agent, exit_stack
async def build_a2a_agent(
self, root_agent
) -> Tuple[BaseAgent, Optional[AsyncExitStack]]:
) -> Tuple[A2ACustomAgent, Optional[AsyncExitStack]]:
"""Build an A2A agent with its sub-agents."""
logger.info(f"Creating A2A agent from {root_agent.agent_card_url}")
logger.info(f"Creating A2A agent from {root_agent.name}")
if not root_agent.agent_card_url:
raise ValueError("agent_card_url is required for a2a agents")
try:
sub_agents = []
if root_agent.config.get("sub_agents"):
if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents")
)
@ -288,6 +330,9 @@ class AgentBuilder:
"""Build a workflow agent with its sub-agents."""
logger.info(f"Creating Workflow agent from {root_agent.name}")
if root_agent.type != "workflow":
raise ValueError(f"Expected workflow agent, got {root_agent.type}")
agent_config = root_agent.config or {}
if not agent_config.get("workflow"):
@ -295,7 +340,7 @@ class AgentBuilder:
try:
sub_agents = []
if root_agent.config.get("sub_agents"):
if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents")
)
@ -304,15 +349,20 @@ class AgentBuilder:
config = root_agent.config or {}
timeout = config.get("timeout", 300)
workflow_agent = WorkflowAgent(
name=root_agent.name,
flow_json=agent_config.get("workflow"),
timeout=timeout,
description=root_agent.description
or f"Workflow Agent for {root_agent.name}",
sub_agents=sub_agents,
db=self.db,
)
kwargs = {
"name": root_agent.name,
"flow_json": agent_config.get("workflow"),
"timeout": timeout,
"description": root_agent.description or f"Workflow Agent for {root_agent.name}",
"sub_agents": sub_agents,
"db": self.db,
}
# Se o root_agent tiver modelo, não passá-lo para o WorkflowAgent
if hasattr(root_agent, 'model') and root_agent.model:
logger.warning(f"Workflow agent {root_agent.name} has model '{root_agent.model}' configured, but workflow agents should not have models. Ignoring model.")
workflow_agent = WorkflowAgent(**kwargs)
logger.info(f"Workflow agent created successfully: {root_agent.name}")
@ -328,6 +378,9 @@ class AgentBuilder:
"""Build a task agent with its sub-agents."""
logger.info(f"Creating Task agent: {root_agent.name}")
if root_agent.type != "task":
raise ValueError(f"Expected task agent, got {root_agent.type}")
agent_config = root_agent.config or {}
if not agent_config.get("tasks"):
@ -336,7 +389,7 @@ class AgentBuilder:
try:
# Get sub-agents if there are any
sub_agents = []
if root_agent.config.get("sub_agents"):
if root_agent.config and root_agent.config.get("sub_agents"):
sub_agents_with_stacks = await self._get_sub_agents(
root_agent.config.get("sub_agents")
)
@ -380,7 +433,11 @@ class AgentBuilder:
f"Processing sub-agents for agent {root_agent.type} (ID: {root_agent.id}, Name: {root_agent.name})"
)
if not root_agent.config.get("sub_agents"):
valid_composite_types = ["sequential", "parallel", "loop"]
if root_agent.type not in valid_composite_types:
raise ValueError(f"Expected composite agent type ({valid_composite_types}), got {root_agent.type}")
if not root_agent.config or not root_agent.config.get("sub_agents"):
logger.error(
f"Sub_agents configuration not found or empty for agent {root_agent.name}"
)
@ -401,39 +458,51 @@ class AgentBuilder:
sub_agents = [agent for agent, _ in sub_agents_with_stacks]
logger.info(f"Extracted sub-agents: {[agent.name for agent in sub_agents]}")
if root_agent.type == "sequential":
logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents")
return (
SequentialAgent(
name=root_agent.name,
sub_agents=sub_agents,
description=root_agent.config.get("description", ""),
),
None,
)
elif root_agent.type == "parallel":
logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents")
return (
ParallelAgent(
name=root_agent.name,
sub_agents=sub_agents,
description=root_agent.config.get("description", ""),
),
None,
)
elif root_agent.type == "loop":
logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents")
return (
LoopAgent(
name=root_agent.name,
sub_agents=sub_agents,
description=root_agent.config.get("description", ""),
max_iterations=root_agent.config.get("max_iterations", 5),
),
None,
)
else:
raise ValueError(f"Invalid agent type: {root_agent.type}")
if not sub_agents:
raise ValueError(f"No valid sub-agents found for {root_agent.type} agent {root_agent.name}")
try:
if root_agent.type == "sequential":
logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents")
return (
SequentialAgent(
name=root_agent.name,
sub_agents=sub_agents,
description=root_agent.description or root_agent.config.get("description", ""),
),
None,
)
elif root_agent.type == "parallel":
logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents")
return (
ParallelAgent(
name=root_agent.name,
sub_agents=sub_agents,
description=root_agent.description or root_agent.config.get("description", ""),
),
None,
)
elif root_agent.type == "loop":
logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents")
max_iterations = root_agent.config.get("max_iterations", 5)
if max_iterations <= 0:
logger.warning(f"Invalid max_iterations ({max_iterations}) for LoopAgent, using default 5")
max_iterations = 5
return (
LoopAgent(
name=root_agent.name,
sub_agents=sub_agents,
description=root_agent.description or root_agent.config.get("description", ""),
max_iterations=max_iterations,
),
None,
)
else:
raise ValueError(f"Invalid composite agent type: {root_agent.type}")
except Exception as e:
logger.error(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")
raise ValueError(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")
async def build_agent(self, root_agent, enabled_tools: List[str] = []) -> Tuple[
LlmAgent
@ -446,13 +515,29 @@ class AgentBuilder:
Optional[AsyncExitStack],
]:
"""Build the appropriate agent based on the type of the root agent."""
if root_agent.type == "llm":
return await self.build_llm_agent(root_agent, enabled_tools)
elif root_agent.type == "a2a":
return await self.build_a2a_agent(root_agent)
elif root_agent.type == "workflow":
return await self.build_workflow_agent(root_agent)
elif root_agent.type == "task":
return await self.build_task_agent(root_agent)
else:
return await self.build_composite_agent(root_agent)
if not root_agent:
raise ValueError("root_agent cannot be None")
if not hasattr(root_agent, 'type') or not root_agent.type:
raise ValueError("root_agent must have a valid type")
logger.info(f"Building agent: {root_agent.name} (type: {root_agent.type})")
try:
if root_agent.type == "llm":
return await self.build_llm_agent(root_agent, enabled_tools)
elif root_agent.type == "a2a":
return await self.build_a2a_agent(root_agent)
elif root_agent.type == "workflow":
return await self.build_workflow_agent(root_agent)
elif root_agent.type == "task":
return await self.build_task_agent(root_agent)
elif root_agent.type in ["sequential", "parallel", "loop"]:
return await self.build_composite_agent(root_agent)
else:
raise ValueError(f"Unknown agent type: {root_agent.type}")
except Exception as e:
logger.error(f"Error building agent {root_agent.name}: {str(e)}")
raise

View File

@ -458,7 +458,7 @@ async def run_agent_stream(
async for event in events_async:
try:
event_dict = event.dict()
event_dict = event.model_dump()
event_dict = convert_sets(event_dict)
if "content" in event_dict and event_dict["content"]:

View File

@ -30,6 +30,7 @@
"""
from datetime import datetime
from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
@ -40,11 +41,14 @@ from typing import AsyncGenerator, Dict, Any, List, TypedDict
import uuid
from src.services.agent_service import get_agent
from src.utils.logger import setup_logger
from sqlalchemy.orm import Session
from langgraph.graph import StateGraph, END
logger = setup_logger(__name__)
class State(TypedDict):
content: List[Event]
@ -63,6 +67,9 @@ class WorkflowAgent(BaseAgent):
This agent allows defining and executing complex workflows between multiple agents
using LangGraph for orchestration.
IMPORTANT: Workflow agents are orchestrators and should NOT have a model configured.
They delegate to sub-agents that have their own models.
"""
# Field declarations for Pydantic
@ -89,6 +96,21 @@ class WorkflowAgent(BaseAgent):
sub_agents: List of sub-agents to be executed after the workflow agent
db: Session
"""
# Workflow agents não devem ter modelos
if 'model' in kwargs:
logger.warning(f"Removing model from workflow agent {name}. Workflow agents should not have models.")
del kwargs['model']
if not flow_json:
raise ValueError(f"Workflow agent {name} requires flow_json configuration")
if not isinstance(flow_json, dict):
raise ValueError(f"Workflow agent {name} flow_json must be a dictionary")
if not flow_json.get('nodes'):
raise ValueError(f"Workflow agent {name} flow_json must contain nodes")
# Initialize base class
super().__init__(
name=name,
@ -98,9 +120,13 @@ class WorkflowAgent(BaseAgent):
db=db,
**kwargs,
)
if hasattr(self, 'model'):
logger.warning(f"Workflow agent {name} had a model attribute. Removing it.")
delattr(self, 'model')
print(
f"Workflow agent initialized with {len(flow_json.get('nodes', []))} nodes"
logger.info(
f"Workflow agent '{name}' initialized with {len(flow_json.get('nodes', []))} nodes"
)
async def _create_node_functions(self, ctx: InvocationContext):
@ -112,11 +138,12 @@ class WorkflowAgent(BaseAgent):
node_id: str,
node_data: Dict[str, Any],
) -> AsyncGenerator[State, None]:
print("\n🏁 INITIAL NODE")
logger.info(f"🏁 INITIAL NODE: {node_id}")
content = state.get("content", [])
if not content:
logger.warning("No content found in initial state")
content = [
Event(
author=f"workflow-node:{node_id}",
@ -128,9 +155,11 @@ class WorkflowAgent(BaseAgent):
"status": "error",
"node_outputs": {},
"cycle_count": 0,
"conversation_history": ctx.session.events,
"conversation_history": ctx.session.events if ctx.session else [],
"session_id": state.get("session_id", ""),
}
return
session_id = state.get("session_id", "")
# Store specific results for this node
@ -149,7 +178,7 @@ class WorkflowAgent(BaseAgent):
"node_outputs": node_outputs,
"cycle_count": 0,
"session_id": session_id,
"conversation_history": ctx.session.events,
"conversation_history": ctx.session.events if ctx.session else [],
}
# Generic function for agent nodes
@ -163,7 +192,7 @@ class WorkflowAgent(BaseAgent):
# Increment cycle counter
cycle_count = state.get("cycle_count", 0) + 1
print(f"\n👤 AGENT: {agent_name} (Cycle {cycle_count})")
logger.info(f"👤 AGENT: {agent_name} (Cycle {cycle_count})")
content = state.get("content", [])
session_id = state.get("session_id", "")
@ -171,14 +200,13 @@ class WorkflowAgent(BaseAgent):
# Get conversation history
conversation_history = state.get("conversation_history", [])
agent = get_agent(self.db, agent_id)
if not agent:
if not agent_id:
logger.error(f"Agent node {node_id} does not have a valid agent_id")
yield {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text="Agent not found")]),
content=Content(parts=[Part(text="Agent ID not configured")]),
)
],
"session_id": session_id,
@ -189,44 +217,84 @@ class WorkflowAgent(BaseAgent):
}
return
# Import moved to inside the function to avoid circular import
from src.services.adk.agent_builder import AgentBuilder
agent = get_agent(self.db, agent_id)
agent_builder = AgentBuilder(self.db)
root_agent, exit_stack = await agent_builder.build_agent(agent)
if not agent:
logger.error(f"Agent not found for ID: {agent_id}")
yield {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text=f"Agent not found: {agent_id}")]),
)
],
"session_id": session_id,
"status": "error",
"node_outputs": {},
"cycle_count": cycle_count,
"conversation_history": conversation_history,
}
return
new_content = []
async for event in root_agent.run_async(ctx):
conversation_history.append(event)
modified_event = Event(
author=f"workflow-node:{node_id}", content=event.content
)
new_content.append(modified_event)
try:
# Import moved to inside the function to avoid circular import
from src.services.adk.agent_builder import AgentBuilder
agent_builder = AgentBuilder(self.db)
root_agent, exit_stack = await agent_builder.build_agent(agent)
print(f"New content: {new_content}")
new_content = []
async for event in root_agent.run_async(ctx):
conversation_history.append(event)
modified_event = Event(
author=f"workflow-node:{node_id}", content=event.content
)
new_content.append(modified_event)
node_outputs = state.get("node_outputs", {})
node_outputs[node_id] = {
"processed_by": agent_name,
"agent_content": new_content,
"cycle": cycle_count,
}
logger.debug(f"Agent {agent_name} generated {len(new_content)} events")
content = content + new_content
node_outputs = state.get("node_outputs", {})
node_outputs[node_id] = {
"processed_by": agent_name,
"agent_id": agent_id,
"agent_content": new_content,
"cycle": cycle_count,
"processed_at": datetime.now().isoformat(),
}
yield {
"content": content,
"status": "processed_by_agent",
"node_outputs": node_outputs,
"cycle_count": cycle_count,
"conversation_history": conversation_history,
"session_id": session_id,
}
content = content + new_content
if exit_stack:
await exit_stack.aclose()
yield {
"content": content,
"status": "processed_by_agent",
"node_outputs": node_outputs,
"cycle_count": cycle_count,
"conversation_history": conversation_history,
"session_id": session_id,
}
if exit_stack:
try:
await exit_stack.aclose()
except Exception as e:
logger.warning(f"Error closing exit stack for agent {agent_name}: {e}")
except Exception as e:
logger.error(f"Error executing agent {agent_name}: {str(e)}")
yield {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text=f"Error executing agent: {str(e)}")]),
)
],
"session_id": session_id,
"status": "agent_error",
"node_outputs": state.get("node_outputs", {}),
"cycle_count": cycle_count,
"conversation_history": conversation_history,
}
# Function for condition nodes
async def condition_node_function(
@ -236,7 +304,7 @@ class WorkflowAgent(BaseAgent):
conditions = node_data.get("conditions", [])
cycle_count = state.get("cycle_count", 0)
print(f"\n🔄 CONDITION: {label} (Cycle {cycle_count})")
logger.info(f"🔄 CONDITION: {label} (Cycle {cycle_count})")
content = state.get("content", [])
conversation_history = state.get("conversation_history", [])
@ -245,16 +313,17 @@ class WorkflowAgent(BaseAgent):
if content and len(content) > 0:
for event in reversed(content):
if (
event.author != "agent"
or not hasattr(event.content, "parts")
or not event.content.parts
hasattr(event, 'author') and
event.author != "user" and
hasattr(event, 'content') and
hasattr(event.content, "parts") and
event.content.parts
):
latest_event = event
break
if latest_event:
print(
f"Evaluating condition only for the most recent event: '{latest_event}'"
)
logger.debug(f"Evaluating condition for latest event from: {latest_event.author}")
# Use only the most recent event for condition evaluation
evaluation_state = state.copy()
@ -273,25 +342,24 @@ class WorkflowAgent(BaseAgent):
operator = condition_data.get("operator")
expected_value = condition_data.get("value")
print(
f" Checking if {field} {operator} '{expected_value}' (current value: '{evaluation_state.get(field, '')}')"
logger.debug(
f"Checking condition: {field} {operator} '{expected_value}'"
)
if self._evaluate_condition(condition, evaluation_state):
conditions_met.append(condition_id)
condition_details.append(
f"{field} {operator} '{expected_value}'"
)
print(f" ✅ Condition {condition_id} met!")
logger.info(f"✅ Condition {condition_id} met!")
else:
condition_details.append(
f"{field} {operator} '{expected_value}'"
)
# Check if the cycle reached the limit (extra security)
if cycle_count >= 10:
print(
f"⚠️ ATTENTION: Cycle limit reached ({cycle_count}). Forcing termination."
)
max_cycles = 10 # Poderia vir da configuração
if cycle_count >= max_cycles:
logger.warning(f"Cycle limit reached ({cycle_count}). Forcing termination.")
condition_content = [
Event(
@ -314,10 +382,10 @@ class WorkflowAgent(BaseAgent):
node_outputs = state.get("node_outputs", {})
node_outputs[node_id] = {
"condition_evaluated": label,
"content_evaluated": content,
"conditions_met": conditions_met,
"condition_details": condition_details,
"cycle": cycle_count,
"evaluated_at": datetime.now().isoformat(),
}
# Prepare a more descriptive message about the conditions
@ -334,7 +402,8 @@ class WorkflowAgent(BaseAgent):
)
]
),
) ]
)
]
content = content + condition_content
yield {
@ -353,7 +422,7 @@ class WorkflowAgent(BaseAgent):
message_type = message_data.get("type", "text")
message_content = message_data.get("content", "")
print(f"\n💬 MESSAGE-NODE: {message_content}")
logger.info(f"💬 MESSAGE-NODE: {message_content}")
content = state.get("content", [])
session_id = state.get("session_id", "")
@ -371,6 +440,8 @@ class WorkflowAgent(BaseAgent):
node_outputs[node_id] = {
"message_type": message_type,
"message_content": message_content,
"label": label,
"processed_at": datetime.now().isoformat(),
}
yield {
@ -378,7 +449,8 @@ class WorkflowAgent(BaseAgent):
"status": "message_added",
"node_outputs": node_outputs,
"cycle_count": state.get("cycle_count", 0),
"conversation_history": conversation_history, "session_id": session_id,
"conversation_history": conversation_history,
"session_id": session_id,
}
async def delay_node_function(
@ -389,6 +461,10 @@ class WorkflowAgent(BaseAgent):
delay_unit = delay_data.get("unit", "seconds")
delay_description = delay_data.get("description", "")
if delay_value <= 0:
logger.warning(f"Invalid delay value: {delay_value}. Using 1 second.")
delay_value = 1
# Convert to seconds based on unit
delay_seconds = delay_value
if delay_unit == "minutes":
@ -397,7 +473,7 @@ class WorkflowAgent(BaseAgent):
delay_seconds = delay_value * 3600
label = node_data.get("label", "delay_node")
print(f"\n⏱️ DELAY-NODE: {delay_value} {delay_unit} - {delay_description}")
logger.info(f"⏱️ DELAY-NODE: {delay_value} {delay_unit} ({delay_seconds}s) - {delay_description}")
content = state.get("content", [])
session_id = state.get("session_id", "")
@ -409,13 +485,17 @@ class WorkflowAgent(BaseAgent):
"delay_value": delay_value,
"delay_unit": delay_unit,
"delay_seconds": delay_seconds,
"delay_description": delay_description,
"delay_start_time": datetime.now().isoformat(),
}
# Actually perform the delay
import asyncio
await asyncio.sleep(delay_seconds)
try:
await asyncio.sleep(delay_seconds)
except asyncio.CancelledError:
logger.warning(f"Delay in node {node_id} was cancelled")
# Continue execution even if delay was cancelled
# Update node outputs with completion information
node_outputs[node_id]["delay_end_time"] = datetime.now().isoformat()
@ -424,7 +504,8 @@ class WorkflowAgent(BaseAgent):
yield {
"content": content,
"status": "delay_completed",
"node_outputs": node_outputs, "cycle_count": state.get("cycle_count", 0),
"node_outputs": node_outputs,
"cycle_count": state.get("cycle_count", 0),
"conversation_history": conversation_history,
"session_id": session_id,
}
@ -452,7 +533,7 @@ class WorkflowAgent(BaseAgent):
result = self._process_condition(operator, actual_value, expected_value)
print(f" Check '{operator}': {result}")
logger.debug(f"Condition check '{operator}': {result}")
return result
return False
@ -488,7 +569,7 @@ class WorkflowAgent(BaseAgent):
if extracted_texts:
joined_text = " ".join(extracted_texts)
print(f" Extracted text from events: '{joined_text[:100]}...'")
logger.debug(f"Extracted text from events: '{joined_text[:100]}...'")
return joined_text
return ""
@ -524,6 +605,7 @@ class WorkflowAgent(BaseAgent):
elif operator in ["matches", "not_matches"]:
return self._check_regex(operator, actual_str, expected_str)
logger.warning(f"Unknown operator: {operator}")
return False
def _check_definition(self, operator, actual_value):
@ -563,8 +645,8 @@ class WorkflowAgent(BaseAgent):
else: # less_than_or_equal
return actual_num <= expected_num
except (ValueError, TypeError):
print(
f" Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
logger.warning(
f"Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
)
return False
@ -579,7 +661,7 @@ class WorkflowAgent(BaseAgent):
else: # not_matches
return not bool(pattern.search(actual_str))
except re.error:
print(f" Error in regular expression: '{expected_str}'")
logger.warning(f"Error in regular expression: '{expected_str}'")
return (
operator == "not_matches"
) # Return True for not_matches, False for matches
@ -589,8 +671,8 @@ class WorkflowAgent(BaseAgent):
expected_lower = expected_str.lower()
actual_lower = actual_str.lower()
print(
f" Comparison '{operator}' without case distinction: '{expected_lower}' in '{actual_lower[:100]}...'"
logger.debug(
f"Comparison '{operator}' case insensitive: '{expected_lower}' in '{actual_lower[:100]}...'"
)
if operator == "contains":
@ -627,14 +709,13 @@ class WorkflowAgent(BaseAgent):
# Routing function for each specific node
def create_router_for_node(node_id: str):
def router(state: State) -> str:
print(f"Routing from node: {node_id}")
logger.debug(f"Routing from node: {node_id}")
# Check if the cycle limit has been reached
cycle_count = state.get("cycle_count", 0)
if cycle_count >= 10:
print(
f"⚠️ Cycle limit ({cycle_count}) reached. Finalizing the flow."
)
max_cycles = 10 # Configurável
if cycle_count >= max_cycles:
logger.warning(f"Cycle limit ({cycle_count}) reached. Finalizing the flow.")
return END
# If it's a condition node, evaluate the conditions
@ -648,32 +729,29 @@ class WorkflowAgent(BaseAgent):
if conditions_met:
any_condition_met = True
condition_id = conditions_met[0]
print(
f"Using stored condition evaluation result: Condition {condition_id} met."
)
logger.debug(f"Using stored condition result: Condition {condition_id} met.")
if (
node_id in edges_map
and condition_id in edges_map[node_id]
):
return edges_map[node_id][condition_id]
else:
print(
"Using stored condition evaluation result: No conditions met."
)
logger.debug("Using stored condition result: No conditions met.")
else:
# Evaluate conditions
for condition in conditions:
condition_id = condition.get("id")
# Get latest event for evaluation, ignoring condition node informational events
content = state.get("content", [])
# Filter out events generated by condition nodes or informational messages
# Filter out events generated by condition nodes or that contain evaluation results
filtered_content = []
for event in content:
# Ignore events from condition nodes or that contain evaluation results
if not hasattr(event, "author") or not (
event.author.startswith("Condition")
or "Condition evaluated:" in str(event)
event.author.startswith("workflow-node:") and
"Condition evaluated:" in str(event)
):
filtered_content.append(event)
@ -687,9 +765,7 @@ class WorkflowAgent(BaseAgent):
if is_condition_met:
any_condition_met = True
print(
f"Condition {condition_id} met. Moving to the next node."
)
logger.debug(f"Condition {condition_id} met. Moving to next node.")
# Find the connection that uses this condition_id as a handle
if (
@ -698,9 +774,7 @@ class WorkflowAgent(BaseAgent):
):
return edges_map[node_id][condition_id]
else:
print(
f"Condition {condition_id} not met. Continuing evaluation or using default path."
)
logger.debug(f"Condition {condition_id} not met.")
# If no condition is met, use the bottom-handle if available
if not any_condition_met:
@ -708,14 +782,10 @@ class WorkflowAgent(BaseAgent):
node_id in edges_map
and "bottom-handle" in edges_map[node_id]
):
print(
"No condition met. Using default path (bottom-handle)."
)
logger.debug("No condition met. Using default path (bottom-handle).")
return edges_map[node_id]["bottom-handle"]
else:
print(
"No condition met and no default path. Closing the flow."
)
logger.debug("No condition met and no default path. Closing the flow.")
return END
# For regular nodes, simply follow the first available connection
@ -731,7 +801,7 @@ class WorkflowAgent(BaseAgent):
return edges_map[node_id][first_handle]
# If there is no output connection, close the flow
print(f"No output connection from node {node_id}. Closing the flow.")
logger.debug(f"No output connection from node {node_id}. Closing the flow.")
return END
return router
@ -745,6 +815,9 @@ class WorkflowAgent(BaseAgent):
# Extract nodes from the flow
nodes = flow_data.get("nodes", [])
if not nodes:
raise ValueError("Flow data must contain at least one node")
# Initialize StateGraph
graph_builder = StateGraph(State)
@ -754,34 +827,60 @@ class WorkflowAgent(BaseAgent):
# Dictionary to store specific functions for each node
node_specific_functions = {}
valid_node_types = set(node_functions.keys())
# Add nodes to the graph
for node in nodes:
node_id = node.get("id")
node_type = node.get("type")
node_data = node.get("data", {})
if node_type in node_functions:
# Create a specific function for this node
def create_node_function(node_type, node_id, node_data):
async def node_function(state):
# Consume the asynchronous generator and return the last result
result = None
if not node_id:
logger.warning(f"Skipping node without ID: {node}")
continue
if node_type not in valid_node_types:
logger.warning(f"Unknown node type '{node_type}' for node {node_id}. Skipping.")
continue
# Create a specific function for this node
def create_node_function(node_type, node_id, node_data):
async def node_function(state):
# Consume the asynchronous generator and return the last result
result = None
try:
async for item in node_functions[node_type](
state, node_id, node_data
):
result = item
return result
except Exception as e:
logger.error(f"Error in node {node_id} ({node_type}): {str(e)}")
# Return error state
return {
"content": [
Event(
author=f"workflow-node:{node_id}",
content=Content(parts=[Part(text=f"Node error: {str(e)}")]),
)
],
"status": "node_error",
"node_outputs": state.get("node_outputs", {}),
"cycle_count": state.get("cycle_count", 0),
"conversation_history": state.get("conversation_history", []),
"session_id": state.get("session_id", ""),
}
return node_function
return node_function
# Add specific function to the dictionary
node_specific_functions[node_id] = create_node_function(
node_type, node_id, node_data
)
# Add specific function to the dictionary
node_specific_functions[node_id] = create_node_function(
node_type, node_id, node_data
)
# Add node to the graph
print(f"Adding node {node_id} of type {node_type}")
graph_builder.add_node(node_id, node_specific_functions[node_id])
# Add node to the graph
logger.debug(f"Adding node {node_id} of type {node_type}")
graph_builder.add_node(node_id, node_specific_functions[node_id])
# Create function to generate specific routers
create_router = self._create_flow_router(flow_data)
@ -808,8 +907,8 @@ class WorkflowAgent(BaseAgent):
node_router = create_router(node_id)
# Add conditional connections
print(f"Adding conditional connections for node {node_id}")
print(f"Possible destinations: {edge_destinations}")
logger.debug(f"Adding conditional connections for node {node_id}")
logger.debug(f"Possible destinations: {list(edge_destinations.keys())}")
graph_builder.add_conditional_edges(
node_id, node_router, edge_destinations
@ -825,35 +924,56 @@ class WorkflowAgent(BaseAgent):
# If there is no start-node, use the first node found
if not entry_point and nodes:
entry_point = nodes[0].get("id")
logger.warning(f"No start-node found, using first node as entry point: {entry_point}")
# Define the entry point
if entry_point:
print(f"Defining entry point: {entry_point}")
logger.info(f"Setting entry point: {entry_point}")
graph_builder.set_entry_point(entry_point)
else:
raise ValueError("No valid entry point found for workflow")
# Compile the graph
return graph_builder.compile()
try:
compiled_graph = graph_builder.compile()
logger.info("Workflow graph compiled successfully")
return compiled_graph
except Exception as e:
logger.error(f"Error compiling workflow graph: {str(e)}")
raise ValueError(f"Error compiling workflow graph: {str(e)}")
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Implementation of the workflow agent executing the defined workflow and returning results."""
if hasattr(self, 'model') and self.model:
logger.error(f"Workflow agent {self.name} should not have a model configured")
raise ValueError(f"Workflow agent {self.name} is an orchestrator and should not have a model. Models should be configured on sub-agents.")
try:
logger.info(f"Starting workflow execution for agent: {self.name}")
logger.debug(f"Context session ID: {ctx.session.id if ctx.session else 'No session'}")
user_message = await self._extract_user_message(ctx)
session_id = self._get_session_id(ctx)
if not self.flow_json:
raise ValueError("Workflow agent has no flow_json configured")
graph = await self._create_graph(ctx, self.flow_json)
initial_state = await self._prepare_initial_state(
ctx, user_message, session_id
)
print("\n🚀 Starting workflow execution:")
print(f"Initial content: {user_message[:100]}...")
logger.info(f"🚀 Starting workflow execution with initial message: {user_message[:100]}...")
# Iterar sobre o AsyncGenerator em vez de usar await
async for event in self._execute_workflow(ctx, graph, initial_state):
yield event
except Exception as e:
logger.error(f"Error in workflow execution: {str(e)}", exc_info=True)
yield await self._handle_workflow_error(e)
async def _extract_user_message(self, ctx: InvocationContext) -> str:
@ -861,24 +981,36 @@ class WorkflowAgent(BaseAgent):
# Try to find message in session events
if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
for event in reversed(ctx.session.events):
if event.author == "user" and event.content and event.content.parts:
print("Message found in session events")
if (
hasattr(event, 'author') and
event.author == "user" and
hasattr(event, 'content') and
event.content and
hasattr(event.content, 'parts') and
event.content.parts
):
logger.debug("User message found in session events")
return event.content.parts[0].text
# Try to find message in session state
if ctx.session and ctx.session.state:
if ctx.session and hasattr(ctx.session, 'state') and ctx.session.state:
if "user_message" in ctx.session.state:
return ctx.session.state["user_message"]
elif "message" in ctx.session.state:
return ctx.session.state["message"]
return ""
logger.warning("No user message found in context")
return "No user message provided"
def _get_session_id(self, ctx: InvocationContext) -> str:
"""Gets or generates a session ID."""
if ctx.session and hasattr(ctx.session, "id"):
if ctx.session and hasattr(ctx.session, "id") and ctx.session.id:
return str(ctx.session.id)
return str(uuid.uuid4())
# Generate a new session ID
new_session_id = str(uuid.uuid4())
logger.debug(f"Generated new session ID: {new_session_id}")
return new_session_id
async def _prepare_initial_state(
self, ctx: InvocationContext, user_message: str, session_id: str
@ -889,9 +1021,13 @@ class WorkflowAgent(BaseAgent):
content=Content(parts=[Part(text=user_message)]),
)
conversation_history = ctx.session.events or [user_event]
conversation_history = []
if ctx.session and hasattr(ctx.session, 'events') and ctx.session.events:
conversation_history = ctx.session.events.copy()
else:
conversation_history = [user_event]
return State(
initial_state = State(
content=[user_event],
status="started",
session_id=session_id,
@ -899,34 +1035,61 @@ class WorkflowAgent(BaseAgent):
node_outputs={},
conversation_history=conversation_history,
)
logger.debug(f"Initial state prepared with {len(conversation_history)} history events")
return initial_state
async def _execute_workflow(
self, ctx: InvocationContext, graph: StateGraph, initial_state: State
) -> AsyncGenerator[Event, None]:
"""Executes the workflow graph and yields events."""
sent_events = 0
total_iterations = 0
max_iterations = 100
async for state in graph.astream(initial_state, {"recursion_limit": 100}):
for node_state in state.values():
content = node_state.get("content", [])
for event in content[sent_events:]:
if event.author != "user":
try:
async for state in graph.astream(initial_state, {"recursion_limit": max_iterations}):
total_iterations += 1
if total_iterations > max_iterations:
logger.warning(f"Maximum iterations ({max_iterations}) reached, stopping workflow")
break
for node_state in state.values():
content = node_state.get("content", [])
# Yield new events that haven't been sent yet
for event in content[sent_events:]:
if hasattr(event, 'author') and event.author != "user":
yield event
sent_events = len(content)
logger.info(f"Workflow completed after {total_iterations} iterations")
except Exception as e:
logger.error(f"Error during workflow execution: {str(e)}")
yield await self._handle_workflow_error(e)
if self.sub_agents:
logger.info(f"Executing {len(self.sub_agents)} sub-agents")
for sub_agent in self.sub_agents:
try:
async for event in sub_agent.run_async(ctx):
yield event
sent_events = len(content)
# Execute sub-agents if any
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
except Exception as e:
logger.error(f"Error executing sub-agent {sub_agent.name}: {str(e)}")
yield await self._handle_workflow_error(e)
async def _handle_workflow_error(self, error: Exception) -> Event:
"""Creates an error event for workflow execution errors."""
error_msg = f"Error executing the workflow agent: {str(error)}"
print(error_msg)
error_msg = f"Error executing workflow agent '{self.name}': {str(error)}"
logger.error(error_msg)
return Event(
author=f"workflow-error:{self.name}",
content=Content(
role="agent",
parts=[Part(text=error_msg)],
),
)
)