Merge 1df83ea2dd
into 6f1d2745fd
This commit is contained in:
commit
61a7c082c3
@ -158,7 +158,7 @@ async def get_agent_messages(
|
||||
|
||||
processed_events = []
|
||||
for event in events:
|
||||
event_dict = event.dict()
|
||||
event_dict = event.model_dump()
|
||||
|
||||
def process_dict(d):
|
||||
if isinstance(d, dict):
|
||||
|
@ -27,7 +27,7 @@
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator, UUID4, ConfigDict
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator, UUID4, ConfigDict
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
@ -40,7 +40,8 @@ class ClientBase(BaseModel):
|
||||
name: str
|
||||
email: Optional[str] = None
|
||||
|
||||
@validator("email")
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def validate_email(cls, v):
|
||||
if v is None:
|
||||
return v
|
||||
@ -58,8 +59,7 @@ class Client(ClientBase):
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ApiKeyBase(BaseModel):
|
||||
@ -101,7 +101,7 @@ class AgentBase(BaseModel):
|
||||
description="Agent type (llm, sequential, parallel, loop, a2a, workflow, task)",
|
||||
)
|
||||
model: Optional[str] = Field(
|
||||
None, description="Agent model (required only for llm type)"
|
||||
None, description="LLM model identifier (required for LLM agents only)"
|
||||
)
|
||||
api_key_id: Optional[UUID4] = Field(
|
||||
None, description="Reference to a stored API Key ID"
|
||||
@ -115,8 +115,13 @@ class AgentBase(BaseModel):
|
||||
)
|
||||
config: Any = Field(None, description="Agent configuration based on type")
|
||||
|
||||
@validator("name")
|
||||
def validate_name(cls, v, values):
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v, info):
|
||||
# Get values from validation context
|
||||
values = info.data if hasattr(info, 'data') else {}
|
||||
|
||||
# A2A agents can have optional names
|
||||
if values.get("type") == "a2a":
|
||||
return v
|
||||
|
||||
@ -127,107 +132,246 @@ class AgentBase(BaseModel):
|
||||
raise ValueError("Agent name cannot contain spaces or special characters")
|
||||
return v
|
||||
|
||||
@validator("type")
|
||||
@field_validator("type")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
if v not in [
|
||||
valid_types = [
|
||||
"llm",
|
||||
"sequential",
|
||||
"parallel",
|
||||
"loop",
|
||||
"a2a",
|
||||
"workflow",
|
||||
"task",
|
||||
]:
|
||||
"sequential",
|
||||
"parallel",
|
||||
"loop",
|
||||
"a2a",
|
||||
"workflow",
|
||||
"task"
|
||||
]
|
||||
if v not in valid_types:
|
||||
raise ValueError(
|
||||
"Invalid agent type. Must be: llm, sequential, parallel, loop, a2a, workflow or task"
|
||||
f"Invalid agent type '{v}'. Must be one of: {', '.join(valid_types)}"
|
||||
)
|
||||
return v
|
||||
|
||||
@validator("agent_card_url")
|
||||
def validate_agent_card_url(cls, v, values):
|
||||
if "type" in values and values["type"] == "a2a":
|
||||
@field_validator("agent_card_url")
|
||||
@classmethod
|
||||
def validate_agent_card_url(cls, v, info):
|
||||
values = info.data if hasattr(info, 'data') else {}
|
||||
|
||||
if values.get("type") == "a2a":
|
||||
if not v:
|
||||
raise ValueError("agent_card_url is required for a2a type agents")
|
||||
if not v.endswith("/.well-known/agent.json"):
|
||||
raise ValueError("agent_card_url must end with /.well-known/agent.json")
|
||||
return v
|
||||
|
||||
@validator("model")
|
||||
def validate_model(cls, v, values):
|
||||
if "type" in values and values["type"] == "llm" and not v:
|
||||
raise ValueError("Model is required for llm type agents")
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, v, info):
|
||||
values = info.data if hasattr(info, 'data') else {}
|
||||
agent_type = values.get("type")
|
||||
|
||||
if agent_type == "llm":
|
||||
# Para agentes LLM, o modelo é obrigatório e não pode ser vazio
|
||||
if not v or (isinstance(v, str) and v.strip() == ""):
|
||||
raise ValueError(
|
||||
"LLM agents require a valid model configuration. "
|
||||
"Please specify a model identifier (e.g., 'gpt-4', 'claude-3-sonnet', 'gemini-pro')"
|
||||
)
|
||||
|
||||
# Verificar se o modelo tem um formato válido
|
||||
if isinstance(v, str) and len(v.strip()) < 3:
|
||||
raise ValueError("Model identifier must be at least 3 characters long")
|
||||
|
||||
elif agent_type in ["workflow", "task", "sequential", "parallel", "loop"]:
|
||||
# Para estes tipos, não devem ter modelo
|
||||
if v and (isinstance(v, str) and v.strip()):
|
||||
# Avisar mas permitir (será removido durante a criação)
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"{agent_type} agents don't need model configuration. Model will be ignored.")
|
||||
|
||||
return v
|
||||
|
||||
@validator("api_key_id")
|
||||
def validate_api_key_id(cls, v, values):
|
||||
@field_validator("api_key_id")
|
||||
@classmethod
|
||||
def validate_api_key_id(cls, v, info):
|
||||
values = info.data if hasattr(info, 'data') else {}
|
||||
agent_type = values.get("type")
|
||||
|
||||
# API key é obrigatório para agentes LLM (a menos que esteja na config)
|
||||
if agent_type == "llm" and not v:
|
||||
# Verificar se tem API key na config
|
||||
config = values.get("config", {})
|
||||
if not config or not config.get("api_key"):
|
||||
# Não falhar aqui, deixar a validação para o momento da criação
|
||||
pass
|
||||
|
||||
return v
|
||||
|
||||
@validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
if "type" in values and values["type"] == "a2a":
|
||||
@field_validator("config")
|
||||
@classmethod
|
||||
def validate_config(cls, v, info):
|
||||
values = info.data if hasattr(info, 'data') else {}
|
||||
agent_type = values.get("type")
|
||||
|
||||
if not agent_type:
|
||||
return v
|
||||
|
||||
# A2A agents têm config opcional
|
||||
if agent_type == "a2a":
|
||||
return v or {}
|
||||
|
||||
if "type" not in values:
|
||||
# Workflow agents têm config específico para workflow
|
||||
if agent_type == "workflow":
|
||||
if v and isinstance(v, dict):
|
||||
if not v.get("workflow"):
|
||||
raise ValueError("Workflow agents must have 'workflow' configuration")
|
||||
return v
|
||||
|
||||
# For workflow agents, we do not perform any validation
|
||||
if "type" in values and values["type"] == "workflow":
|
||||
return v
|
||||
|
||||
if not v and values.get("type") != "a2a":
|
||||
# Config é obrigatório para outros tipos (exceto a2a)
|
||||
if not v and agent_type not in ["a2a"]:
|
||||
raise ValueError(
|
||||
f"Configuration is required for {values.get('type')} agent type"
|
||||
f"Configuration is required for {agent_type} agent type"
|
||||
)
|
||||
|
||||
if values["type"] == "llm":
|
||||
if isinstance(v, dict):
|
||||
try:
|
||||
# Convert the dictionary to LLMConfig
|
||||
v = LLMConfig(**v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid LLM configuration for agent: {str(e)}")
|
||||
elif not isinstance(v, LLMConfig):
|
||||
raise ValueError("Invalid LLM configuration for agent")
|
||||
elif values["type"] in ["sequential", "parallel", "loop"]:
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError(f'Invalid configuration for agent {values["type"]}')
|
||||
if "sub_agents" not in v:
|
||||
raise ValueError(f'Agent {values["type"]} must have sub_agents')
|
||||
if not isinstance(v["sub_agents"], list):
|
||||
raise ValueError("sub_agents must be a list")
|
||||
if not v["sub_agents"]:
|
||||
raise ValueError(
|
||||
f'Agent {values["type"]} must have at least one sub-agent'
|
||||
)
|
||||
elif values["type"] == "task":
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError(f'Invalid configuration for agent {values["type"]}')
|
||||
if "tasks" not in v:
|
||||
raise ValueError(f'Agent {values["type"]} must have tasks')
|
||||
if not isinstance(v["tasks"], list):
|
||||
raise ValueError("tasks must be a list")
|
||||
if not v["tasks"]:
|
||||
raise ValueError(f'Agent {values["type"]} must have at least one task')
|
||||
for task in v["tasks"]:
|
||||
if not isinstance(task, dict):
|
||||
raise ValueError("Each task must be a dictionary")
|
||||
required_fields = ["agent_id", "description", "expected_output"]
|
||||
for field in required_fields:
|
||||
if field not in task:
|
||||
raise ValueError(f"Task missing required field: {field}")
|
||||
|
||||
if "sub_agents" in v and v["sub_agents"] is not None:
|
||||
if not isinstance(v["sub_agents"], list):
|
||||
raise ValueError("sub_agents must be a list")
|
||||
|
||||
return v
|
||||
# Validação específica por tipo
|
||||
if agent_type == "llm":
|
||||
return cls._validate_llm_config(v)
|
||||
elif agent_type in ["sequential", "parallel", "loop"]:
|
||||
return cls._validate_composite_config(v, agent_type)
|
||||
elif agent_type == "task":
|
||||
return cls._validate_task_config(v)
|
||||
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def _validate_llm_config(cls, v):
|
||||
"""Valida configuração para agentes LLM"""
|
||||
if isinstance(v, dict):
|
||||
try:
|
||||
# Convert the dictionary to LLMConfig
|
||||
v = LLMConfig(**v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid LLM configuration: {str(e)}")
|
||||
elif not isinstance(v, LLMConfig):
|
||||
raise ValueError("Invalid LLM configuration format")
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def _validate_composite_config(cls, v, agent_type):
|
||||
"""Valida configuração para agentes compostos (sequential, parallel, loop)"""
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError(f'Configuration for {agent_type} agent must be a dictionary')
|
||||
|
||||
if "sub_agents" not in v:
|
||||
raise ValueError(f'{agent_type} agents must have sub_agents configuration')
|
||||
|
||||
if not isinstance(v["sub_agents"], list):
|
||||
raise ValueError("sub_agents must be a list")
|
||||
|
||||
if not v["sub_agents"]:
|
||||
raise ValueError(
|
||||
f'{agent_type} agents must have at least one sub-agent'
|
||||
)
|
||||
|
||||
# Validação específica para LoopAgent
|
||||
if agent_type == "loop":
|
||||
max_iterations = v.get("max_iterations", 5)
|
||||
if not isinstance(max_iterations, int) or max_iterations <= 0:
|
||||
raise ValueError("max_iterations must be a positive integer")
|
||||
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def _validate_task_config(cls, v):
|
||||
"""Valida configuração para agentes de task"""
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError('Configuration for task agent must be a dictionary')
|
||||
|
||||
if "tasks" not in v:
|
||||
raise ValueError('Task agents must have tasks configuration')
|
||||
|
||||
if not isinstance(v["tasks"], list):
|
||||
raise ValueError("tasks must be a list")
|
||||
|
||||
if not v["tasks"]:
|
||||
raise ValueError('Task agents must have at least one task')
|
||||
|
||||
# Validar cada task individualmente
|
||||
for i, task in enumerate(v["tasks"]):
|
||||
if not isinstance(task, dict):
|
||||
raise ValueError(f"Task {i+1} must be a dictionary")
|
||||
|
||||
required_fields = ["agent_id", "description", "expected_output"]
|
||||
for field in required_fields:
|
||||
if field not in task:
|
||||
raise ValueError(f"Task {i+1} missing required field: {field}")
|
||||
|
||||
# Verificar se os campos não estão vazios
|
||||
if not task[field] or (isinstance(task[field], str) and not task[field].strip()):
|
||||
raise ValueError(f"Task {i+1} field '{field}' cannot be empty")
|
||||
|
||||
# Validar sub_agents se presente
|
||||
if "sub_agents" in v and v["sub_agents"] is not None:
|
||||
if not isinstance(v["sub_agents"], list):
|
||||
raise ValueError("sub_agents must be a list")
|
||||
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_agent_consistency(self):
|
||||
"""Validação cruzada entre campos do agente"""
|
||||
|
||||
# Verificar consistência entre tipo e configurações
|
||||
if self.type == "llm":
|
||||
# LLM agents devem ter modelo
|
||||
if not self.model or (isinstance(self.model, str) and self.model.strip() == ""):
|
||||
raise ValueError("LLM agents must have a valid model")
|
||||
|
||||
# LLM agents devem ter API key (na config ou api_key_id)
|
||||
has_api_key = bool(self.api_key_id)
|
||||
if not has_api_key and self.config:
|
||||
config_dict = self.config if isinstance(self.config, dict) else self.config.__dict__
|
||||
has_api_key = bool(config_dict.get("api_key"))
|
||||
|
||||
if not has_api_key:
|
||||
raise ValueError("LLM agents must have an API key configured")
|
||||
|
||||
elif self.type in ["workflow", "task", "sequential", "parallel", "loop"]:
|
||||
# Orchestrator agents não devem ter modelo
|
||||
if self.model and isinstance(self.model, str) and self.model.strip():
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"{self.type} agents don't need model configuration. Clearing model.")
|
||||
self.model = None
|
||||
|
||||
elif self.type == "a2a":
|
||||
# A2A agents devem ter agent_card_url
|
||||
if not self.agent_card_url:
|
||||
raise ValueError("A2A agents must have agent_card_url")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class AgentCreate(AgentBase):
|
||||
client_id: UUID
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_creation_requirements(self):
|
||||
"""Validações específicas para criação de agentes"""
|
||||
|
||||
# Chamar validação da classe pai
|
||||
super().validate_agent_consistency()
|
||||
|
||||
# Validações específicas para criação
|
||||
if self.type == "llm":
|
||||
# Para criação, ser mais rigoroso com modelo
|
||||
if not self.model or len(self.model.strip()) < 3:
|
||||
raise ValueError(
|
||||
"LLM agents require a valid model identifier (minimum 3 characters). "
|
||||
"Examples: 'gpt-4', 'claude-3-sonnet', 'gemini-pro'"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class Agent(AgentBase):
|
||||
id: UUID
|
||||
@ -237,17 +381,17 @@ class Agent(AgentBase):
|
||||
agent_card_url: Optional[str] = None
|
||||
folder_id: Optional[UUID4] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@validator("agent_card_url", pre=True)
|
||||
def set_agent_card_url(cls, v, values):
|
||||
@field_validator("agent_card_url", mode='before')
|
||||
@classmethod
|
||||
def set_agent_card_url(cls, v, info):
|
||||
if v:
|
||||
return v
|
||||
|
||||
values = info.data if hasattr(info, 'data') else {}
|
||||
if "id" in values:
|
||||
from os import getenv
|
||||
|
||||
return f"{getenv('API_URL', '')}/api/v1/a2a/{values['id']}/.well-known/agent.json"
|
||||
|
||||
return v
|
||||
@ -262,6 +406,7 @@ class ToolConfig(BaseModel):
|
||||
inputModes: List[str] = Field(default_factory=list)
|
||||
outputModes: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Last edited by Arley Peter on 2025-05-17
|
||||
class MCPServerBase(BaseModel):
|
||||
name: str
|
||||
@ -272,6 +417,29 @@ class MCPServerBase(BaseModel):
|
||||
tools: Optional[List[ToolConfig]] = Field(default_factory=list)
|
||||
type: str = Field(default="official")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("MCP Server name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("config_type")
|
||||
@classmethod
|
||||
def validate_config_type(cls, v):
|
||||
valid_types = ["studio", "sse"]
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"config_type must be one of: {valid_types}")
|
||||
return v
|
||||
|
||||
@field_validator("type")
|
||||
@classmethod
|
||||
def validate_type(cls, v):
|
||||
valid_types = ["official", "community"]
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"type must be one of: {valid_types}")
|
||||
return v
|
||||
|
||||
|
||||
class MCPServerCreate(MCPServerBase):
|
||||
pass
|
||||
@ -282,8 +450,7 @@ class MCPServer(MCPServerBase):
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ToolBase(BaseModel):
|
||||
@ -292,6 +459,13 @@ class ToolBase(BaseModel):
|
||||
config_json: Dict[str, Any] = Field(default_factory=dict)
|
||||
environments: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Tool name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ToolCreate(ToolBase):
|
||||
pass
|
||||
@ -302,14 +476,20 @@ class Tool(ToolBase):
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentFolderBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Folder name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class AgentFolderCreate(AgentFolderBase):
|
||||
client_id: UUID4
|
||||
@ -326,3 +506,86 @@ class AgentFolder(AgentFolderBase):
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentTypeInfo(BaseModel):
|
||||
"""Informações sobre tipos de agente válidos"""
|
||||
type: str
|
||||
requires_model: bool
|
||||
requires_config: bool
|
||||
description: str
|
||||
|
||||
@classmethod
|
||||
def get_valid_types(cls) -> Dict[str, 'AgentTypeInfo']:
|
||||
"""Retorna informações sobre todos os tipos válidos de agente"""
|
||||
return {
|
||||
"llm": cls(
|
||||
type="llm",
|
||||
requires_model=True,
|
||||
requires_config=True,
|
||||
description="Large Language Model agent - requires model and API key"
|
||||
),
|
||||
"workflow": cls(
|
||||
type="workflow",
|
||||
requires_model=False,
|
||||
requires_config=True,
|
||||
description="Workflow orchestrator agent - uses LangGraph for complex flows"
|
||||
),
|
||||
"task": cls(
|
||||
type="task",
|
||||
requires_model=False,
|
||||
requires_config=True,
|
||||
description="Task management agent - coordinates multiple tasks"
|
||||
),
|
||||
"sequential": cls(
|
||||
type="sequential",
|
||||
requires_model=False,
|
||||
requires_config=True,
|
||||
description="Sequential execution agent - runs sub-agents in order"
|
||||
),
|
||||
"parallel": cls(
|
||||
type="parallel",
|
||||
requires_model=False,
|
||||
requires_config=True,
|
||||
description="Parallel execution agent - runs sub-agents concurrently"
|
||||
),
|
||||
"loop": cls(
|
||||
type="loop",
|
||||
requires_model=False,
|
||||
requires_config=True,
|
||||
description="Loop execution agent - repeats sub-agents with conditions"
|
||||
),
|
||||
"a2a": cls(
|
||||
type="a2a",
|
||||
requires_model=False,
|
||||
requires_config=False,
|
||||
description="Agent-to-Agent communication - external agent integration"
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class ModelValidationResult(BaseModel):
|
||||
"""Resultado da validação de modelo"""
|
||||
is_valid: bool
|
||||
error_message: Optional[str] = None
|
||||
warnings: List[str] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def success(cls, warnings: List[str] = None) -> 'ModelValidationResult':
|
||||
return cls(is_valid=True, warnings=warnings or [])
|
||||
|
||||
@classmethod
|
||||
def failure(cls, error_message: str) -> 'ModelValidationResult':
|
||||
return cls(is_valid=False, error_message=error_message)
|
||||
|
||||
|
||||
class AgentValidationSummary(BaseModel):
|
||||
"""Resumo de validação de agente"""
|
||||
agent_id: Optional[UUID] = None
|
||||
agent_name: Optional[str] = None
|
||||
agent_type: str
|
||||
is_valid: bool
|
||||
model_validation: ModelValidationResult
|
||||
config_validation: ModelValidationResult
|
||||
general_errors: List[str] = Field(default_factory=list)
|
||||
warnings: List[str] = Field(default_factory=list)
|
@ -67,15 +67,39 @@ class AgentBuilder:
|
||||
if agent_tools_ids and isinstance(agent_tools_ids, list):
|
||||
for agent_tool_id in agent_tools_ids:
|
||||
sub_agent = get_agent(self.db, agent_tool_id)
|
||||
llm_agent, _ = await self.build_llm_agent(sub_agent)
|
||||
if llm_agent:
|
||||
agent_tools.append(AgentTool(agent=llm_agent))
|
||||
if sub_agent:
|
||||
# Verificar se o sub_agent é do tipo LLM antes de criar LlmAgent
|
||||
if sub_agent.type == "llm":
|
||||
llm_agent, _ = await self.build_llm_agent(sub_agent)
|
||||
if llm_agent:
|
||||
agent_tools.append(AgentTool(agent=llm_agent))
|
||||
else:
|
||||
logger.warning(f"Agent tool {agent_tool_id} is not of type 'llm', skipping")
|
||||
else:
|
||||
logger.warning(f"Agent tool {agent_tool_id} not found")
|
||||
return agent_tools
|
||||
|
||||
def _validate_llm_agent_model(self, agent) -> None:
|
||||
"""Validate that LLM agent has a proper model configuration."""
|
||||
if not hasattr(agent, 'model') or not agent.model:
|
||||
logger.error(f"LLM agent {agent.name} does not have a model configured")
|
||||
raise ValueError(f"LLM agent {agent.name} requires a model configuration")
|
||||
|
||||
if isinstance(agent.model, str) and agent.model.strip() == "":
|
||||
logger.error(f"LLM agent {agent.name} has an empty model string")
|
||||
raise ValueError(f"LLM agent {agent.name} has an empty model configuration")
|
||||
|
||||
logger.info(f"Model validation passed for agent {agent.name}: {agent.model}")
|
||||
|
||||
async def _create_llm_agent(
|
||||
self, agent, enabled_tools: List[str] = []
|
||||
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
||||
"""Create an LLM agent from the agent data."""
|
||||
|
||||
self._validate_llm_agent_model(agent)
|
||||
|
||||
logger.info(f"Creating LLM agent: {agent.name} with model: {agent.model}")
|
||||
|
||||
# Get custom tools from the configuration
|
||||
custom_tools = []
|
||||
custom_tools = self.custom_tool_builder.build_tools(agent.config)
|
||||
@ -110,7 +134,7 @@ class AgentBuilder:
|
||||
current_day_of_week=current_day_of_week,
|
||||
current_date_iso=current_date_iso,
|
||||
current_time=current_time,
|
||||
)
|
||||
) if agent.instruction else ""
|
||||
|
||||
# add role on beginning of the prompt
|
||||
if agent.role:
|
||||
@ -170,21 +194,27 @@ class AgentBuilder:
|
||||
f"Agent {agent.name} does not have a configured API key"
|
||||
)
|
||||
|
||||
return (
|
||||
LlmAgent(
|
||||
if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
|
||||
raise ValueError(f"Cannot create LiteLlm with empty model for agent {agent.name}")
|
||||
|
||||
try:
|
||||
llm_agent = LlmAgent(
|
||||
name=agent.name,
|
||||
model=LiteLlm(model=agent.model, api_key=api_key),
|
||||
instruction=formatted_prompt,
|
||||
description=agent.description,
|
||||
tools=all_tools,
|
||||
),
|
||||
mcp_exit_stack,
|
||||
)
|
||||
)
|
||||
logger.info(f"LLM agent created successfully: {agent.name}")
|
||||
return llm_agent, mcp_exit_stack
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating LLM agent {agent.name}: {str(e)}")
|
||||
raise ValueError(f"Error creating LLM agent {agent.name}: {str(e)}")
|
||||
|
||||
async def _get_sub_agents(
|
||||
self, sub_agent_ids: List[str]
|
||||
) -> List[Tuple[LlmAgent, Optional[AsyncExitStack]]]:
|
||||
"""Get and create LLM sub-agents."""
|
||||
) -> List[Tuple[BaseAgent, Optional[AsyncExitStack]]]:
|
||||
"""Get and create sub-agents with proper type validation."""
|
||||
sub_agents = []
|
||||
for sub_agent_id in sub_agent_ids:
|
||||
sub_agent_id_str = str(sub_agent_id)
|
||||
@ -197,39 +227,50 @@ class AgentBuilder:
|
||||
|
||||
logger.info(f"Sub-agent found: {agent.name} (type: {agent.type})")
|
||||
|
||||
if agent.type == "llm":
|
||||
sub_agent, exit_stack = await self._create_llm_agent(agent)
|
||||
elif agent.type == "a2a":
|
||||
sub_agent, exit_stack = await self.build_a2a_agent(agent)
|
||||
elif agent.type == "workflow":
|
||||
sub_agent, exit_stack = await self.build_workflow_agent(agent)
|
||||
elif agent.type == "task":
|
||||
sub_agent, exit_stack = await self.build_task_agent(agent)
|
||||
elif agent.type == "sequential":
|
||||
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||
elif agent.type == "parallel":
|
||||
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||
elif agent.type == "loop":
|
||||
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type: {agent.type}")
|
||||
try:
|
||||
if agent.type == "llm":
|
||||
# Verificar se tem modelo antes de criar
|
||||
if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""):
|
||||
logger.error(f"LLM sub-agent {agent.name} does not have a model configured")
|
||||
raise ValueError(f"LLM sub-agent {agent.name} requires a model configuration")
|
||||
sub_agent, exit_stack = await self._create_llm_agent(agent)
|
||||
elif agent.type == "a2a":
|
||||
sub_agent, exit_stack = await self.build_a2a_agent(agent)
|
||||
elif agent.type == "workflow":
|
||||
# Workflow agents não precisam de modelo
|
||||
sub_agent, exit_stack = await self.build_workflow_agent(agent)
|
||||
elif agent.type == "task":
|
||||
sub_agent, exit_stack = await self.build_task_agent(agent)
|
||||
elif agent.type == "sequential":
|
||||
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||
elif agent.type == "parallel":
|
||||
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||
elif agent.type == "loop":
|
||||
sub_agent, exit_stack = await self.build_composite_agent(agent)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type: {agent.type}")
|
||||
|
||||
sub_agents.append(sub_agent)
|
||||
logger.info(f"Sub-agent added: {agent.name}")
|
||||
sub_agents.append((sub_agent, exit_stack))
|
||||
logger.info(f"Sub-agent added: {agent.name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating sub-agent {agent.name}: {str(e)}")
|
||||
raise ValueError(f"Error creating sub-agent {agent.name}: {str(e)}")
|
||||
|
||||
logger.info(f"Sub-agents created: {len(sub_agents)}")
|
||||
logger.info(f"Sub-agents: {str(sub_agents)}")
|
||||
|
||||
return sub_agents
|
||||
|
||||
async def build_llm_agent(
|
||||
self, root_agent, enabled_tools: List[str] = []
|
||||
) -> Tuple[LlmAgent, Optional[AsyncExitStack]]:
|
||||
"""Build an LLM agent with its sub-agents."""
|
||||
logger.info("Creating LLM agent")
|
||||
logger.info(f"Creating LLM agent: {root_agent.name}")
|
||||
|
||||
if root_agent.type != "llm":
|
||||
raise ValueError(f"Expected LLM agent, got {root_agent.type}")
|
||||
|
||||
sub_agents = []
|
||||
if root_agent.config.get("sub_agents"):
|
||||
if root_agent.config and root_agent.config.get("sub_agents"):
|
||||
sub_agents_with_stacks = await self._get_sub_agents(
|
||||
root_agent.config.get("sub_agents")
|
||||
)
|
||||
@ -241,20 +282,21 @@ class AgentBuilder:
|
||||
if sub_agents:
|
||||
root_llm_agent.sub_agents = sub_agents
|
||||
|
||||
logger.info(f"LLM agent built successfully: {root_agent.name}")
|
||||
return root_llm_agent, exit_stack
|
||||
|
||||
async def build_a2a_agent(
|
||||
self, root_agent
|
||||
) -> Tuple[BaseAgent, Optional[AsyncExitStack]]:
|
||||
) -> Tuple[A2ACustomAgent, Optional[AsyncExitStack]]:
|
||||
"""Build an A2A agent with its sub-agents."""
|
||||
logger.info(f"Creating A2A agent from {root_agent.agent_card_url}")
|
||||
logger.info(f"Creating A2A agent from {root_agent.name}")
|
||||
|
||||
if not root_agent.agent_card_url:
|
||||
raise ValueError("agent_card_url is required for a2a agents")
|
||||
|
||||
try:
|
||||
sub_agents = []
|
||||
if root_agent.config.get("sub_agents"):
|
||||
if root_agent.config and root_agent.config.get("sub_agents"):
|
||||
sub_agents_with_stacks = await self._get_sub_agents(
|
||||
root_agent.config.get("sub_agents")
|
||||
)
|
||||
@ -288,6 +330,9 @@ class AgentBuilder:
|
||||
"""Build a workflow agent with its sub-agents."""
|
||||
logger.info(f"Creating Workflow agent from {root_agent.name}")
|
||||
|
||||
if root_agent.type != "workflow":
|
||||
raise ValueError(f"Expected workflow agent, got {root_agent.type}")
|
||||
|
||||
agent_config = root_agent.config or {}
|
||||
|
||||
if not agent_config.get("workflow"):
|
||||
@ -295,7 +340,7 @@ class AgentBuilder:
|
||||
|
||||
try:
|
||||
sub_agents = []
|
||||
if root_agent.config.get("sub_agents"):
|
||||
if root_agent.config and root_agent.config.get("sub_agents"):
|
||||
sub_agents_with_stacks = await self._get_sub_agents(
|
||||
root_agent.config.get("sub_agents")
|
||||
)
|
||||
@ -304,15 +349,20 @@ class AgentBuilder:
|
||||
config = root_agent.config or {}
|
||||
timeout = config.get("timeout", 300)
|
||||
|
||||
workflow_agent = WorkflowAgent(
|
||||
name=root_agent.name,
|
||||
flow_json=agent_config.get("workflow"),
|
||||
timeout=timeout,
|
||||
description=root_agent.description
|
||||
or f"Workflow Agent for {root_agent.name}",
|
||||
sub_agents=sub_agents,
|
||||
db=self.db,
|
||||
)
|
||||
kwargs = {
|
||||
"name": root_agent.name,
|
||||
"flow_json": agent_config.get("workflow"),
|
||||
"timeout": timeout,
|
||||
"description": root_agent.description or f"Workflow Agent for {root_agent.name}",
|
||||
"sub_agents": sub_agents,
|
||||
"db": self.db,
|
||||
}
|
||||
|
||||
# Se o root_agent tiver modelo, não passá-lo para o WorkflowAgent
|
||||
if hasattr(root_agent, 'model') and root_agent.model:
|
||||
logger.warning(f"Workflow agent {root_agent.name} has model '{root_agent.model}' configured, but workflow agents should not have models. Ignoring model.")
|
||||
|
||||
workflow_agent = WorkflowAgent(**kwargs)
|
||||
|
||||
logger.info(f"Workflow agent created successfully: {root_agent.name}")
|
||||
|
||||
@ -328,6 +378,9 @@ class AgentBuilder:
|
||||
"""Build a task agent with its sub-agents."""
|
||||
logger.info(f"Creating Task agent: {root_agent.name}")
|
||||
|
||||
if root_agent.type != "task":
|
||||
raise ValueError(f"Expected task agent, got {root_agent.type}")
|
||||
|
||||
agent_config = root_agent.config or {}
|
||||
|
||||
if not agent_config.get("tasks"):
|
||||
@ -336,7 +389,7 @@ class AgentBuilder:
|
||||
try:
|
||||
# Get sub-agents if there are any
|
||||
sub_agents = []
|
||||
if root_agent.config.get("sub_agents"):
|
||||
if root_agent.config and root_agent.config.get("sub_agents"):
|
||||
sub_agents_with_stacks = await self._get_sub_agents(
|
||||
root_agent.config.get("sub_agents")
|
||||
)
|
||||
@ -380,7 +433,11 @@ class AgentBuilder:
|
||||
f"Processing sub-agents for agent {root_agent.type} (ID: {root_agent.id}, Name: {root_agent.name})"
|
||||
)
|
||||
|
||||
if not root_agent.config.get("sub_agents"):
|
||||
valid_composite_types = ["sequential", "parallel", "loop"]
|
||||
if root_agent.type not in valid_composite_types:
|
||||
raise ValueError(f"Expected composite agent type ({valid_composite_types}), got {root_agent.type}")
|
||||
|
||||
if not root_agent.config or not root_agent.config.get("sub_agents"):
|
||||
logger.error(
|
||||
f"Sub_agents configuration not found or empty for agent {root_agent.name}"
|
||||
)
|
||||
@ -401,39 +458,51 @@ class AgentBuilder:
|
||||
sub_agents = [agent for agent, _ in sub_agents_with_stacks]
|
||||
logger.info(f"Extracted sub-agents: {[agent.name for agent in sub_agents]}")
|
||||
|
||||
if root_agent.type == "sequential":
|
||||
logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents")
|
||||
return (
|
||||
SequentialAgent(
|
||||
name=root_agent.name,
|
||||
sub_agents=sub_agents,
|
||||
description=root_agent.config.get("description", ""),
|
||||
),
|
||||
None,
|
||||
)
|
||||
elif root_agent.type == "parallel":
|
||||
logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents")
|
||||
return (
|
||||
ParallelAgent(
|
||||
name=root_agent.name,
|
||||
sub_agents=sub_agents,
|
||||
description=root_agent.config.get("description", ""),
|
||||
),
|
||||
None,
|
||||
)
|
||||
elif root_agent.type == "loop":
|
||||
logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents")
|
||||
return (
|
||||
LoopAgent(
|
||||
name=root_agent.name,
|
||||
sub_agents=sub_agents,
|
||||
description=root_agent.config.get("description", ""),
|
||||
max_iterations=root_agent.config.get("max_iterations", 5),
|
||||
),
|
||||
None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type: {root_agent.type}")
|
||||
if not sub_agents:
|
||||
raise ValueError(f"No valid sub-agents found for {root_agent.type} agent {root_agent.name}")
|
||||
|
||||
try:
|
||||
if root_agent.type == "sequential":
|
||||
logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents")
|
||||
return (
|
||||
SequentialAgent(
|
||||
name=root_agent.name,
|
||||
sub_agents=sub_agents,
|
||||
description=root_agent.description or root_agent.config.get("description", ""),
|
||||
),
|
||||
None,
|
||||
)
|
||||
elif root_agent.type == "parallel":
|
||||
logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents")
|
||||
return (
|
||||
ParallelAgent(
|
||||
name=root_agent.name,
|
||||
sub_agents=sub_agents,
|
||||
description=root_agent.description or root_agent.config.get("description", ""),
|
||||
),
|
||||
None,
|
||||
)
|
||||
elif root_agent.type == "loop":
|
||||
logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents")
|
||||
max_iterations = root_agent.config.get("max_iterations", 5)
|
||||
if max_iterations <= 0:
|
||||
logger.warning(f"Invalid max_iterations ({max_iterations}) for LoopAgent, using default 5")
|
||||
max_iterations = 5
|
||||
return (
|
||||
LoopAgent(
|
||||
name=root_agent.name,
|
||||
sub_agents=sub_agents,
|
||||
description=root_agent.description or root_agent.config.get("description", ""),
|
||||
max_iterations=max_iterations,
|
||||
),
|
||||
None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid composite agent type: {root_agent.type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")
|
||||
raise ValueError(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}")
|
||||
|
||||
async def build_agent(self, root_agent, enabled_tools: List[str] = []) -> Tuple[
|
||||
LlmAgent
|
||||
@ -446,13 +515,29 @@ class AgentBuilder:
|
||||
Optional[AsyncExitStack],
|
||||
]:
|
||||
"""Build the appropriate agent based on the type of the root agent."""
|
||||
if root_agent.type == "llm":
|
||||
return await self.build_llm_agent(root_agent, enabled_tools)
|
||||
elif root_agent.type == "a2a":
|
||||
return await self.build_a2a_agent(root_agent)
|
||||
elif root_agent.type == "workflow":
|
||||
return await self.build_workflow_agent(root_agent)
|
||||
elif root_agent.type == "task":
|
||||
return await self.build_task_agent(root_agent)
|
||||
else:
|
||||
return await self.build_composite_agent(root_agent)
|
||||
|
||||
if not root_agent:
|
||||
raise ValueError("root_agent cannot be None")
|
||||
|
||||
if not hasattr(root_agent, 'type') or not root_agent.type:
|
||||
raise ValueError("root_agent must have a valid type")
|
||||
|
||||
logger.info(f"Building agent: {root_agent.name} (type: {root_agent.type})")
|
||||
|
||||
try:
|
||||
if root_agent.type == "llm":
|
||||
return await self.build_llm_agent(root_agent, enabled_tools)
|
||||
elif root_agent.type == "a2a":
|
||||
return await self.build_a2a_agent(root_agent)
|
||||
elif root_agent.type == "workflow":
|
||||
return await self.build_workflow_agent(root_agent)
|
||||
elif root_agent.type == "task":
|
||||
return await self.build_task_agent(root_agent)
|
||||
elif root_agent.type in ["sequential", "parallel", "loop"]:
|
||||
return await self.build_composite_agent(root_agent)
|
||||
else:
|
||||
raise ValueError(f"Unknown agent type: {root_agent.type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building agent {root_agent.name}: {str(e)}")
|
||||
raise
|
@ -458,7 +458,7 @@ async def run_agent_stream(
|
||||
|
||||
async for event in events_async:
|
||||
try:
|
||||
event_dict = event.dict()
|
||||
event_dict = event.model_dump()
|
||||
event_dict = convert_sets(event_dict)
|
||||
|
||||
if "content" in event_dict and event_dict["content"]:
|
||||
|
@ -30,6 +30,7 @@
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
"""
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
from google.adk.agents import BaseAgent
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
@ -40,11 +41,14 @@ from typing import AsyncGenerator, Dict, Any, List, TypedDict
|
||||
import uuid
|
||||
|
||||
from src.services.agent_service import get_agent
|
||||
from src.utils.logger import setup_logger
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langgraph.graph import StateGraph, END
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class State(TypedDict):
|
||||
content: List[Event]
|
||||
@ -63,6 +67,9 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
This agent allows defining and executing complex workflows between multiple agents
|
||||
using LangGraph for orchestration.
|
||||
|
||||
IMPORTANT: Workflow agents are orchestrators and should NOT have a model configured.
|
||||
They delegate to sub-agents that have their own models.
|
||||
"""
|
||||
|
||||
# Field declarations for Pydantic
|
||||
@ -89,6 +96,21 @@ class WorkflowAgent(BaseAgent):
|
||||
sub_agents: List of sub-agents to be executed after the workflow agent
|
||||
db: Session
|
||||
"""
|
||||
|
||||
# Workflow agents não devem ter modelos
|
||||
if 'model' in kwargs:
|
||||
logger.warning(f"Removing model from workflow agent {name}. Workflow agents should not have models.")
|
||||
del kwargs['model']
|
||||
|
||||
if not flow_json:
|
||||
raise ValueError(f"Workflow agent {name} requires flow_json configuration")
|
||||
|
||||
if not isinstance(flow_json, dict):
|
||||
raise ValueError(f"Workflow agent {name} flow_json must be a dictionary")
|
||||
|
||||
if not flow_json.get('nodes'):
|
||||
raise ValueError(f"Workflow agent {name} flow_json must contain nodes")
|
||||
|
||||
# Initialize base class
|
||||
super().__init__(
|
||||
name=name,
|
||||
@ -98,9 +120,13 @@ class WorkflowAgent(BaseAgent):
|
||||
db=db,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if hasattr(self, 'model'):
|
||||
logger.warning(f"Workflow agent {name} had a model attribute. Removing it.")
|
||||
delattr(self, 'model')
|
||||
|
||||
print(
|
||||
f"Workflow agent initialized with {len(flow_json.get('nodes', []))} nodes"
|
||||
logger.info(
|
||||
f"Workflow agent '{name}' initialized with {len(flow_json.get('nodes', []))} nodes"
|
||||
)
|
||||
|
||||
async def _create_node_functions(self, ctx: InvocationContext):
|
||||
@ -112,11 +138,12 @@ class WorkflowAgent(BaseAgent):
|
||||
node_id: str,
|
||||
node_data: Dict[str, Any],
|
||||
) -> AsyncGenerator[State, None]:
|
||||
print("\n🏁 INITIAL NODE")
|
||||
logger.info(f"🏁 INITIAL NODE: {node_id}")
|
||||
|
||||
content = state.get("content", [])
|
||||
|
||||
if not content:
|
||||
logger.warning("No content found in initial state")
|
||||
content = [
|
||||
Event(
|
||||
author=f"workflow-node:{node_id}",
|
||||
@ -128,9 +155,11 @@ class WorkflowAgent(BaseAgent):
|
||||
"status": "error",
|
||||
"node_outputs": {},
|
||||
"cycle_count": 0,
|
||||
"conversation_history": ctx.session.events,
|
||||
"conversation_history": ctx.session.events if ctx.session else [],
|
||||
"session_id": state.get("session_id", ""),
|
||||
}
|
||||
return
|
||||
|
||||
session_id = state.get("session_id", "")
|
||||
|
||||
# Store specific results for this node
|
||||
@ -149,7 +178,7 @@ class WorkflowAgent(BaseAgent):
|
||||
"node_outputs": node_outputs,
|
||||
"cycle_count": 0,
|
||||
"session_id": session_id,
|
||||
"conversation_history": ctx.session.events,
|
||||
"conversation_history": ctx.session.events if ctx.session else [],
|
||||
}
|
||||
|
||||
# Generic function for agent nodes
|
||||
@ -163,7 +192,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
# Increment cycle counter
|
||||
cycle_count = state.get("cycle_count", 0) + 1
|
||||
print(f"\n👤 AGENT: {agent_name} (Cycle {cycle_count})")
|
||||
logger.info(f"👤 AGENT: {agent_name} (Cycle {cycle_count})")
|
||||
|
||||
content = state.get("content", [])
|
||||
session_id = state.get("session_id", "")
|
||||
@ -171,14 +200,13 @@ class WorkflowAgent(BaseAgent):
|
||||
# Get conversation history
|
||||
conversation_history = state.get("conversation_history", [])
|
||||
|
||||
agent = get_agent(self.db, agent_id)
|
||||
|
||||
if not agent:
|
||||
if not agent_id:
|
||||
logger.error(f"Agent node {node_id} does not have a valid agent_id")
|
||||
yield {
|
||||
"content": [
|
||||
Event(
|
||||
author=f"workflow-node:{node_id}",
|
||||
content=Content(parts=[Part(text="Agent not found")]),
|
||||
content=Content(parts=[Part(text="Agent ID not configured")]),
|
||||
)
|
||||
],
|
||||
"session_id": session_id,
|
||||
@ -189,44 +217,84 @@ class WorkflowAgent(BaseAgent):
|
||||
}
|
||||
return
|
||||
|
||||
# Import moved to inside the function to avoid circular import
|
||||
from src.services.adk.agent_builder import AgentBuilder
|
||||
agent = get_agent(self.db, agent_id)
|
||||
|
||||
agent_builder = AgentBuilder(self.db)
|
||||
root_agent, exit_stack = await agent_builder.build_agent(agent)
|
||||
if not agent:
|
||||
logger.error(f"Agent not found for ID: {agent_id}")
|
||||
yield {
|
||||
"content": [
|
||||
Event(
|
||||
author=f"workflow-node:{node_id}",
|
||||
content=Content(parts=[Part(text=f"Agent not found: {agent_id}")]),
|
||||
)
|
||||
],
|
||||
"session_id": session_id,
|
||||
"status": "error",
|
||||
"node_outputs": {},
|
||||
"cycle_count": cycle_count,
|
||||
"conversation_history": conversation_history,
|
||||
}
|
||||
return
|
||||
|
||||
new_content = []
|
||||
async for event in root_agent.run_async(ctx):
|
||||
conversation_history.append(event)
|
||||
|
||||
modified_event = Event(
|
||||
author=f"workflow-node:{node_id}", content=event.content
|
||||
)
|
||||
new_content.append(modified_event)
|
||||
try:
|
||||
# Import moved to inside the function to avoid circular import
|
||||
from src.services.adk.agent_builder import AgentBuilder
|
||||
|
||||
agent_builder = AgentBuilder(self.db)
|
||||
root_agent, exit_stack = await agent_builder.build_agent(agent)
|
||||
|
||||
print(f"New content: {new_content}")
|
||||
new_content = []
|
||||
async for event in root_agent.run_async(ctx):
|
||||
conversation_history.append(event)
|
||||
|
||||
modified_event = Event(
|
||||
author=f"workflow-node:{node_id}", content=event.content
|
||||
)
|
||||
new_content.append(modified_event)
|
||||
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
node_outputs[node_id] = {
|
||||
"processed_by": agent_name,
|
||||
"agent_content": new_content,
|
||||
"cycle": cycle_count,
|
||||
}
|
||||
logger.debug(f"Agent {agent_name} generated {len(new_content)} events")
|
||||
|
||||
content = content + new_content
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
node_outputs[node_id] = {
|
||||
"processed_by": agent_name,
|
||||
"agent_id": agent_id,
|
||||
"agent_content": new_content,
|
||||
"cycle": cycle_count,
|
||||
"processed_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
yield {
|
||||
"content": content,
|
||||
"status": "processed_by_agent",
|
||||
"node_outputs": node_outputs,
|
||||
"cycle_count": cycle_count,
|
||||
"conversation_history": conversation_history,
|
||||
"session_id": session_id,
|
||||
}
|
||||
content = content + new_content
|
||||
|
||||
if exit_stack:
|
||||
await exit_stack.aclose()
|
||||
yield {
|
||||
"content": content,
|
||||
"status": "processed_by_agent",
|
||||
"node_outputs": node_outputs,
|
||||
"cycle_count": cycle_count,
|
||||
"conversation_history": conversation_history,
|
||||
"session_id": session_id,
|
||||
}
|
||||
|
||||
if exit_stack:
|
||||
try:
|
||||
await exit_stack.aclose()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing exit stack for agent {agent_name}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing agent {agent_name}: {str(e)}")
|
||||
yield {
|
||||
"content": [
|
||||
Event(
|
||||
author=f"workflow-node:{node_id}",
|
||||
content=Content(parts=[Part(text=f"Error executing agent: {str(e)}")]),
|
||||
)
|
||||
],
|
||||
"session_id": session_id,
|
||||
"status": "agent_error",
|
||||
"node_outputs": state.get("node_outputs", {}),
|
||||
"cycle_count": cycle_count,
|
||||
"conversation_history": conversation_history,
|
||||
}
|
||||
|
||||
# Function for condition nodes
|
||||
async def condition_node_function(
|
||||
@ -236,7 +304,7 @@ class WorkflowAgent(BaseAgent):
|
||||
conditions = node_data.get("conditions", [])
|
||||
cycle_count = state.get("cycle_count", 0)
|
||||
|
||||
print(f"\n🔄 CONDITION: {label} (Cycle {cycle_count})")
|
||||
logger.info(f"🔄 CONDITION: {label} (Cycle {cycle_count})")
|
||||
|
||||
content = state.get("content", [])
|
||||
conversation_history = state.get("conversation_history", [])
|
||||
@ -245,16 +313,17 @@ class WorkflowAgent(BaseAgent):
|
||||
if content and len(content) > 0:
|
||||
for event in reversed(content):
|
||||
if (
|
||||
event.author != "agent"
|
||||
or not hasattr(event.content, "parts")
|
||||
or not event.content.parts
|
||||
hasattr(event, 'author') and
|
||||
event.author != "user" and
|
||||
hasattr(event, 'content') and
|
||||
hasattr(event.content, "parts") and
|
||||
event.content.parts
|
||||
):
|
||||
latest_event = event
|
||||
break
|
||||
|
||||
if latest_event:
|
||||
print(
|
||||
f"Evaluating condition only for the most recent event: '{latest_event}'"
|
||||
)
|
||||
logger.debug(f"Evaluating condition for latest event from: {latest_event.author}")
|
||||
|
||||
# Use only the most recent event for condition evaluation
|
||||
evaluation_state = state.copy()
|
||||
@ -273,25 +342,24 @@ class WorkflowAgent(BaseAgent):
|
||||
operator = condition_data.get("operator")
|
||||
expected_value = condition_data.get("value")
|
||||
|
||||
print(
|
||||
f" Checking if {field} {operator} '{expected_value}' (current value: '{evaluation_state.get(field, '')}')"
|
||||
logger.debug(
|
||||
f"Checking condition: {field} {operator} '{expected_value}'"
|
||||
)
|
||||
|
||||
if self._evaluate_condition(condition, evaluation_state):
|
||||
conditions_met.append(condition_id)
|
||||
condition_details.append(
|
||||
f"{field} {operator} '{expected_value}' ✅"
|
||||
)
|
||||
print(f" ✅ Condition {condition_id} met!")
|
||||
logger.info(f"✅ Condition {condition_id} met!")
|
||||
else:
|
||||
condition_details.append(
|
||||
f"{field} {operator} '{expected_value}' ❌"
|
||||
)
|
||||
|
||||
# Check if the cycle reached the limit (extra security)
|
||||
if cycle_count >= 10:
|
||||
print(
|
||||
f"⚠️ ATTENTION: Cycle limit reached ({cycle_count}). Forcing termination."
|
||||
)
|
||||
max_cycles = 10 # Poderia vir da configuração
|
||||
if cycle_count >= max_cycles:
|
||||
logger.warning(f"Cycle limit reached ({cycle_count}). Forcing termination.")
|
||||
|
||||
condition_content = [
|
||||
Event(
|
||||
@ -314,10 +382,10 @@ class WorkflowAgent(BaseAgent):
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
node_outputs[node_id] = {
|
||||
"condition_evaluated": label,
|
||||
"content_evaluated": content,
|
||||
"conditions_met": conditions_met,
|
||||
"condition_details": condition_details,
|
||||
"cycle": cycle_count,
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# Prepare a more descriptive message about the conditions
|
||||
@ -334,7 +402,8 @@ class WorkflowAgent(BaseAgent):
|
||||
)
|
||||
]
|
||||
),
|
||||
) ]
|
||||
)
|
||||
]
|
||||
content = content + condition_content
|
||||
|
||||
yield {
|
||||
@ -353,7 +422,7 @@ class WorkflowAgent(BaseAgent):
|
||||
message_type = message_data.get("type", "text")
|
||||
message_content = message_data.get("content", "")
|
||||
|
||||
print(f"\n💬 MESSAGE-NODE: {message_content}")
|
||||
logger.info(f"💬 MESSAGE-NODE: {message_content}")
|
||||
|
||||
content = state.get("content", [])
|
||||
session_id = state.get("session_id", "")
|
||||
@ -371,6 +440,8 @@ class WorkflowAgent(BaseAgent):
|
||||
node_outputs[node_id] = {
|
||||
"message_type": message_type,
|
||||
"message_content": message_content,
|
||||
"label": label,
|
||||
"processed_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
yield {
|
||||
@ -378,7 +449,8 @@ class WorkflowAgent(BaseAgent):
|
||||
"status": "message_added",
|
||||
"node_outputs": node_outputs,
|
||||
"cycle_count": state.get("cycle_count", 0),
|
||||
"conversation_history": conversation_history, "session_id": session_id,
|
||||
"conversation_history": conversation_history,
|
||||
"session_id": session_id,
|
||||
}
|
||||
|
||||
async def delay_node_function(
|
||||
@ -389,6 +461,10 @@ class WorkflowAgent(BaseAgent):
|
||||
delay_unit = delay_data.get("unit", "seconds")
|
||||
delay_description = delay_data.get("description", "")
|
||||
|
||||
if delay_value <= 0:
|
||||
logger.warning(f"Invalid delay value: {delay_value}. Using 1 second.")
|
||||
delay_value = 1
|
||||
|
||||
# Convert to seconds based on unit
|
||||
delay_seconds = delay_value
|
||||
if delay_unit == "minutes":
|
||||
@ -397,7 +473,7 @@ class WorkflowAgent(BaseAgent):
|
||||
delay_seconds = delay_value * 3600
|
||||
|
||||
label = node_data.get("label", "delay_node")
|
||||
print(f"\n⏱️ DELAY-NODE: {delay_value} {delay_unit} - {delay_description}")
|
||||
logger.info(f"⏱️ DELAY-NODE: {delay_value} {delay_unit} ({delay_seconds}s) - {delay_description}")
|
||||
|
||||
content = state.get("content", [])
|
||||
session_id = state.get("session_id", "")
|
||||
@ -409,13 +485,17 @@ class WorkflowAgent(BaseAgent):
|
||||
"delay_value": delay_value,
|
||||
"delay_unit": delay_unit,
|
||||
"delay_seconds": delay_seconds,
|
||||
"delay_description": delay_description,
|
||||
"delay_start_time": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# Actually perform the delay
|
||||
import asyncio
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
try:
|
||||
await asyncio.sleep(delay_seconds)
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(f"Delay in node {node_id} was cancelled")
|
||||
# Continue execution even if delay was cancelled
|
||||
|
||||
# Update node outputs with completion information
|
||||
node_outputs[node_id]["delay_end_time"] = datetime.now().isoformat()
|
||||
@ -424,7 +504,8 @@ class WorkflowAgent(BaseAgent):
|
||||
yield {
|
||||
"content": content,
|
||||
"status": "delay_completed",
|
||||
"node_outputs": node_outputs, "cycle_count": state.get("cycle_count", 0),
|
||||
"node_outputs": node_outputs,
|
||||
"cycle_count": state.get("cycle_count", 0),
|
||||
"conversation_history": conversation_history,
|
||||
"session_id": session_id,
|
||||
}
|
||||
@ -452,7 +533,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
result = self._process_condition(operator, actual_value, expected_value)
|
||||
|
||||
print(f" Check '{operator}': {result}")
|
||||
logger.debug(f"Condition check '{operator}': {result}")
|
||||
return result
|
||||
|
||||
return False
|
||||
@ -488,7 +569,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
if extracted_texts:
|
||||
joined_text = " ".join(extracted_texts)
|
||||
print(f" Extracted text from events: '{joined_text[:100]}...'")
|
||||
logger.debug(f"Extracted text from events: '{joined_text[:100]}...'")
|
||||
return joined_text
|
||||
|
||||
return ""
|
||||
@ -524,6 +605,7 @@ class WorkflowAgent(BaseAgent):
|
||||
elif operator in ["matches", "not_matches"]:
|
||||
return self._check_regex(operator, actual_str, expected_str)
|
||||
|
||||
logger.warning(f"Unknown operator: {operator}")
|
||||
return False
|
||||
|
||||
def _check_definition(self, operator, actual_value):
|
||||
@ -563,8 +645,8 @@ class WorkflowAgent(BaseAgent):
|
||||
else: # less_than_or_equal
|
||||
return actual_num <= expected_num
|
||||
except (ValueError, TypeError):
|
||||
print(
|
||||
f" Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
|
||||
logger.warning(
|
||||
f"Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
|
||||
)
|
||||
return False
|
||||
|
||||
@ -579,7 +661,7 @@ class WorkflowAgent(BaseAgent):
|
||||
else: # not_matches
|
||||
return not bool(pattern.search(actual_str))
|
||||
except re.error:
|
||||
print(f" Error in regular expression: '{expected_str}'")
|
||||
logger.warning(f"Error in regular expression: '{expected_str}'")
|
||||
return (
|
||||
operator == "not_matches"
|
||||
) # Return True for not_matches, False for matches
|
||||
@ -589,8 +671,8 @@ class WorkflowAgent(BaseAgent):
|
||||
expected_lower = expected_str.lower()
|
||||
actual_lower = actual_str.lower()
|
||||
|
||||
print(
|
||||
f" Comparison '{operator}' without case distinction: '{expected_lower}' in '{actual_lower[:100]}...'"
|
||||
logger.debug(
|
||||
f"Comparison '{operator}' case insensitive: '{expected_lower}' in '{actual_lower[:100]}...'"
|
||||
)
|
||||
|
||||
if operator == "contains":
|
||||
@ -627,14 +709,13 @@ class WorkflowAgent(BaseAgent):
|
||||
# Routing function for each specific node
|
||||
def create_router_for_node(node_id: str):
|
||||
def router(state: State) -> str:
|
||||
print(f"Routing from node: {node_id}")
|
||||
logger.debug(f"Routing from node: {node_id}")
|
||||
|
||||
# Check if the cycle limit has been reached
|
||||
cycle_count = state.get("cycle_count", 0)
|
||||
if cycle_count >= 10:
|
||||
print(
|
||||
f"⚠️ Cycle limit ({cycle_count}) reached. Finalizing the flow."
|
||||
)
|
||||
max_cycles = 10 # Configurável
|
||||
if cycle_count >= max_cycles:
|
||||
logger.warning(f"Cycle limit ({cycle_count}) reached. Finalizing the flow.")
|
||||
return END
|
||||
|
||||
# If it's a condition node, evaluate the conditions
|
||||
@ -648,32 +729,29 @@ class WorkflowAgent(BaseAgent):
|
||||
if conditions_met:
|
||||
any_condition_met = True
|
||||
condition_id = conditions_met[0]
|
||||
print(
|
||||
f"Using stored condition evaluation result: Condition {condition_id} met."
|
||||
)
|
||||
logger.debug(f"Using stored condition result: Condition {condition_id} met.")
|
||||
if (
|
||||
node_id in edges_map
|
||||
and condition_id in edges_map[node_id]
|
||||
):
|
||||
return edges_map[node_id][condition_id]
|
||||
else:
|
||||
print(
|
||||
"Using stored condition evaluation result: No conditions met."
|
||||
)
|
||||
logger.debug("Using stored condition result: No conditions met.")
|
||||
else:
|
||||
# Evaluate conditions
|
||||
for condition in conditions:
|
||||
condition_id = condition.get("id")
|
||||
|
||||
# Get latest event for evaluation, ignoring condition node informational events
|
||||
content = state.get("content", [])
|
||||
|
||||
# Filter out events generated by condition nodes or informational messages
|
||||
# Filter out events generated by condition nodes or that contain evaluation results
|
||||
filtered_content = []
|
||||
for event in content:
|
||||
# Ignore events from condition nodes or that contain evaluation results
|
||||
if not hasattr(event, "author") or not (
|
||||
event.author.startswith("Condition")
|
||||
or "Condition evaluated:" in str(event)
|
||||
event.author.startswith("workflow-node:") and
|
||||
"Condition evaluated:" in str(event)
|
||||
):
|
||||
filtered_content.append(event)
|
||||
|
||||
@ -687,9 +765,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
if is_condition_met:
|
||||
any_condition_met = True
|
||||
print(
|
||||
f"Condition {condition_id} met. Moving to the next node."
|
||||
)
|
||||
logger.debug(f"Condition {condition_id} met. Moving to next node.")
|
||||
|
||||
# Find the connection that uses this condition_id as a handle
|
||||
if (
|
||||
@ -698,9 +774,7 @@ class WorkflowAgent(BaseAgent):
|
||||
):
|
||||
return edges_map[node_id][condition_id]
|
||||
else:
|
||||
print(
|
||||
f"Condition {condition_id} not met. Continuing evaluation or using default path."
|
||||
)
|
||||
logger.debug(f"Condition {condition_id} not met.")
|
||||
|
||||
# If no condition is met, use the bottom-handle if available
|
||||
if not any_condition_met:
|
||||
@ -708,14 +782,10 @@ class WorkflowAgent(BaseAgent):
|
||||
node_id in edges_map
|
||||
and "bottom-handle" in edges_map[node_id]
|
||||
):
|
||||
print(
|
||||
"No condition met. Using default path (bottom-handle)."
|
||||
)
|
||||
logger.debug("No condition met. Using default path (bottom-handle).")
|
||||
return edges_map[node_id]["bottom-handle"]
|
||||
else:
|
||||
print(
|
||||
"No condition met and no default path. Closing the flow."
|
||||
)
|
||||
logger.debug("No condition met and no default path. Closing the flow.")
|
||||
return END
|
||||
|
||||
# For regular nodes, simply follow the first available connection
|
||||
@ -731,7 +801,7 @@ class WorkflowAgent(BaseAgent):
|
||||
return edges_map[node_id][first_handle]
|
||||
|
||||
# If there is no output connection, close the flow
|
||||
print(f"No output connection from node {node_id}. Closing the flow.")
|
||||
logger.debug(f"No output connection from node {node_id}. Closing the flow.")
|
||||
return END
|
||||
|
||||
return router
|
||||
@ -745,6 +815,9 @@ class WorkflowAgent(BaseAgent):
|
||||
# Extract nodes from the flow
|
||||
nodes = flow_data.get("nodes", [])
|
||||
|
||||
if not nodes:
|
||||
raise ValueError("Flow data must contain at least one node")
|
||||
|
||||
# Initialize StateGraph
|
||||
graph_builder = StateGraph(State)
|
||||
|
||||
@ -754,34 +827,60 @@ class WorkflowAgent(BaseAgent):
|
||||
# Dictionary to store specific functions for each node
|
||||
node_specific_functions = {}
|
||||
|
||||
valid_node_types = set(node_functions.keys())
|
||||
|
||||
# Add nodes to the graph
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
node_type = node.get("type")
|
||||
node_data = node.get("data", {})
|
||||
|
||||
if node_type in node_functions:
|
||||
# Create a specific function for this node
|
||||
def create_node_function(node_type, node_id, node_data):
|
||||
async def node_function(state):
|
||||
# Consume the asynchronous generator and return the last result
|
||||
result = None
|
||||
if not node_id:
|
||||
logger.warning(f"Skipping node without ID: {node}")
|
||||
continue
|
||||
|
||||
if node_type not in valid_node_types:
|
||||
logger.warning(f"Unknown node type '{node_type}' for node {node_id}. Skipping.")
|
||||
continue
|
||||
|
||||
# Create a specific function for this node
|
||||
def create_node_function(node_type, node_id, node_data):
|
||||
async def node_function(state):
|
||||
# Consume the asynchronous generator and return the last result
|
||||
result = None
|
||||
try:
|
||||
async for item in node_functions[node_type](
|
||||
state, node_id, node_data
|
||||
):
|
||||
result = item
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in node {node_id} ({node_type}): {str(e)}")
|
||||
# Return error state
|
||||
return {
|
||||
"content": [
|
||||
Event(
|
||||
author=f"workflow-node:{node_id}",
|
||||
content=Content(parts=[Part(text=f"Node error: {str(e)}")]),
|
||||
)
|
||||
],
|
||||
"status": "node_error",
|
||||
"node_outputs": state.get("node_outputs", {}),
|
||||
"cycle_count": state.get("cycle_count", 0),
|
||||
"conversation_history": state.get("conversation_history", []),
|
||||
"session_id": state.get("session_id", ""),
|
||||
}
|
||||
|
||||
return node_function
|
||||
return node_function
|
||||
|
||||
# Add specific function to the dictionary
|
||||
node_specific_functions[node_id] = create_node_function(
|
||||
node_type, node_id, node_data
|
||||
)
|
||||
# Add specific function to the dictionary
|
||||
node_specific_functions[node_id] = create_node_function(
|
||||
node_type, node_id, node_data
|
||||
)
|
||||
|
||||
# Add node to the graph
|
||||
print(f"Adding node {node_id} of type {node_type}")
|
||||
graph_builder.add_node(node_id, node_specific_functions[node_id])
|
||||
# Add node to the graph
|
||||
logger.debug(f"Adding node {node_id} of type {node_type}")
|
||||
graph_builder.add_node(node_id, node_specific_functions[node_id])
|
||||
|
||||
# Create function to generate specific routers
|
||||
create_router = self._create_flow_router(flow_data)
|
||||
@ -808,8 +907,8 @@ class WorkflowAgent(BaseAgent):
|
||||
node_router = create_router(node_id)
|
||||
|
||||
# Add conditional connections
|
||||
print(f"Adding conditional connections for node {node_id}")
|
||||
print(f"Possible destinations: {edge_destinations}")
|
||||
logger.debug(f"Adding conditional connections for node {node_id}")
|
||||
logger.debug(f"Possible destinations: {list(edge_destinations.keys())}")
|
||||
|
||||
graph_builder.add_conditional_edges(
|
||||
node_id, node_router, edge_destinations
|
||||
@ -825,35 +924,56 @@ class WorkflowAgent(BaseAgent):
|
||||
# If there is no start-node, use the first node found
|
||||
if not entry_point and nodes:
|
||||
entry_point = nodes[0].get("id")
|
||||
logger.warning(f"No start-node found, using first node as entry point: {entry_point}")
|
||||
|
||||
# Define the entry point
|
||||
if entry_point:
|
||||
print(f"Defining entry point: {entry_point}")
|
||||
logger.info(f"Setting entry point: {entry_point}")
|
||||
graph_builder.set_entry_point(entry_point)
|
||||
else:
|
||||
raise ValueError("No valid entry point found for workflow")
|
||||
|
||||
# Compile the graph
|
||||
return graph_builder.compile()
|
||||
try:
|
||||
compiled_graph = graph_builder.compile()
|
||||
logger.info("Workflow graph compiled successfully")
|
||||
return compiled_graph
|
||||
except Exception as e:
|
||||
logger.error(f"Error compiling workflow graph: {str(e)}")
|
||||
raise ValueError(f"Error compiling workflow graph: {str(e)}")
|
||||
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Implementation of the workflow agent executing the defined workflow and returning results."""
|
||||
|
||||
if hasattr(self, 'model') and self.model:
|
||||
logger.error(f"Workflow agent {self.name} should not have a model configured")
|
||||
raise ValueError(f"Workflow agent {self.name} is an orchestrator and should not have a model. Models should be configured on sub-agents.")
|
||||
|
||||
try:
|
||||
logger.info(f"Starting workflow execution for agent: {self.name}")
|
||||
logger.debug(f"Context session ID: {ctx.session.id if ctx.session else 'No session'}")
|
||||
|
||||
user_message = await self._extract_user_message(ctx)
|
||||
session_id = self._get_session_id(ctx)
|
||||
|
||||
if not self.flow_json:
|
||||
raise ValueError("Workflow agent has no flow_json configured")
|
||||
|
||||
graph = await self._create_graph(ctx, self.flow_json)
|
||||
initial_state = await self._prepare_initial_state(
|
||||
ctx, user_message, session_id
|
||||
)
|
||||
|
||||
print("\n🚀 Starting workflow execution:")
|
||||
print(f"Initial content: {user_message[:100]}...")
|
||||
logger.info(f"🚀 Starting workflow execution with initial message: {user_message[:100]}...")
|
||||
|
||||
# Iterar sobre o AsyncGenerator em vez de usar await
|
||||
async for event in self._execute_workflow(ctx, graph, initial_state):
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in workflow execution: {str(e)}", exc_info=True)
|
||||
yield await self._handle_workflow_error(e)
|
||||
|
||||
async def _extract_user_message(self, ctx: InvocationContext) -> str:
|
||||
@ -861,24 +981,36 @@ class WorkflowAgent(BaseAgent):
|
||||
# Try to find message in session events
|
||||
if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
|
||||
for event in reversed(ctx.session.events):
|
||||
if event.author == "user" and event.content and event.content.parts:
|
||||
print("Message found in session events")
|
||||
if (
|
||||
hasattr(event, 'author') and
|
||||
event.author == "user" and
|
||||
hasattr(event, 'content') and
|
||||
event.content and
|
||||
hasattr(event.content, 'parts') and
|
||||
event.content.parts
|
||||
):
|
||||
logger.debug("User message found in session events")
|
||||
return event.content.parts[0].text
|
||||
|
||||
# Try to find message in session state
|
||||
if ctx.session and ctx.session.state:
|
||||
if ctx.session and hasattr(ctx.session, 'state') and ctx.session.state:
|
||||
if "user_message" in ctx.session.state:
|
||||
return ctx.session.state["user_message"]
|
||||
elif "message" in ctx.session.state:
|
||||
return ctx.session.state["message"]
|
||||
|
||||
return ""
|
||||
logger.warning("No user message found in context")
|
||||
return "No user message provided"
|
||||
|
||||
def _get_session_id(self, ctx: InvocationContext) -> str:
|
||||
"""Gets or generates a session ID."""
|
||||
if ctx.session and hasattr(ctx.session, "id"):
|
||||
if ctx.session and hasattr(ctx.session, "id") and ctx.session.id:
|
||||
return str(ctx.session.id)
|
||||
return str(uuid.uuid4())
|
||||
|
||||
# Generate a new session ID
|
||||
new_session_id = str(uuid.uuid4())
|
||||
logger.debug(f"Generated new session ID: {new_session_id}")
|
||||
return new_session_id
|
||||
|
||||
async def _prepare_initial_state(
|
||||
self, ctx: InvocationContext, user_message: str, session_id: str
|
||||
@ -889,9 +1021,13 @@ class WorkflowAgent(BaseAgent):
|
||||
content=Content(parts=[Part(text=user_message)]),
|
||||
)
|
||||
|
||||
conversation_history = ctx.session.events or [user_event]
|
||||
conversation_history = []
|
||||
if ctx.session and hasattr(ctx.session, 'events') and ctx.session.events:
|
||||
conversation_history = ctx.session.events.copy()
|
||||
else:
|
||||
conversation_history = [user_event]
|
||||
|
||||
return State(
|
||||
initial_state = State(
|
||||
content=[user_event],
|
||||
status="started",
|
||||
session_id=session_id,
|
||||
@ -899,34 +1035,61 @@ class WorkflowAgent(BaseAgent):
|
||||
node_outputs={},
|
||||
conversation_history=conversation_history,
|
||||
)
|
||||
|
||||
logger.debug(f"Initial state prepared with {len(conversation_history)} history events")
|
||||
return initial_state
|
||||
|
||||
async def _execute_workflow(
|
||||
self, ctx: InvocationContext, graph: StateGraph, initial_state: State
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Executes the workflow graph and yields events."""
|
||||
sent_events = 0
|
||||
total_iterations = 0
|
||||
max_iterations = 100
|
||||
|
||||
async for state in graph.astream(initial_state, {"recursion_limit": 100}):
|
||||
for node_state in state.values():
|
||||
content = node_state.get("content", [])
|
||||
for event in content[sent_events:]:
|
||||
if event.author != "user":
|
||||
try:
|
||||
async for state in graph.astream(initial_state, {"recursion_limit": max_iterations}):
|
||||
total_iterations += 1
|
||||
|
||||
if total_iterations > max_iterations:
|
||||
logger.warning(f"Maximum iterations ({max_iterations}) reached, stopping workflow")
|
||||
break
|
||||
|
||||
for node_state in state.values():
|
||||
content = node_state.get("content", [])
|
||||
|
||||
# Yield new events that haven't been sent yet
|
||||
for event in content[sent_events:]:
|
||||
if hasattr(event, 'author') and event.author != "user":
|
||||
yield event
|
||||
|
||||
sent_events = len(content)
|
||||
|
||||
logger.info(f"Workflow completed after {total_iterations} iterations")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during workflow execution: {str(e)}")
|
||||
yield await self._handle_workflow_error(e)
|
||||
|
||||
if self.sub_agents:
|
||||
logger.info(f"Executing {len(self.sub_agents)} sub-agents")
|
||||
for sub_agent in self.sub_agents:
|
||||
try:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
sent_events = len(content)
|
||||
|
||||
# Execute sub-agents if any
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing sub-agent {sub_agent.name}: {str(e)}")
|
||||
yield await self._handle_workflow_error(e)
|
||||
|
||||
async def _handle_workflow_error(self, error: Exception) -> Event:
|
||||
"""Creates an error event for workflow execution errors."""
|
||||
error_msg = f"Error executing the workflow agent: {str(error)}"
|
||||
print(error_msg)
|
||||
error_msg = f"Error executing workflow agent '{self.name}': {str(error)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
return Event(
|
||||
author=f"workflow-error:{self.name}",
|
||||
content=Content(
|
||||
role="agent",
|
||||
parts=[Part(text=error_msg)],
|
||||
),
|
||||
)
|
||||
)
|
Loading…
Reference in New Issue
Block a user