evo-ai/src/services/workflow_agent.py

709 lines
27 KiB
Python

import asyncio
import re
import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, ClassVar, Dict, List, Optional, TypedDict

from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.adk.runners import Runner
from google.genai.types import Content, Part
from langgraph.graph import END, StateGraph
from sqlalchemy.orm import Session

from src.services.agent_service import get_agent
class State(TypedDict):
    """Shared LangGraph state exchanged between the workflow's nodes.

    Each node function in ``WorkflowAgent`` yields a dict using these keys;
    LangGraph merges that update into the state seen by the next node.
    """

    # Events accumulated so far: the user message, agent output and the
    # summary events written by condition nodes.
    content: List[Event]
    # Marker set by the most recent node, e.g. "started", "error",
    # "processed_by_agent", "condition_evaluated", "cycle_limit_reached".
    status: str
    # ADK session id of the run (a random UUID when no session id exists).
    session_id: str
    # Additional fields to store any node outputs, keyed by node id.
    node_outputs: Dict[str, Any]
    # Cycle counter to prevent infinite loops (incremented by agent nodes).
    cycle_count: int
    # Events handed along to agents; initialised from the session's events.
    conversation_history: List[Event]
class WorkflowAgent(BaseAgent):
    """
    Agent that implements workflow flows using LangGraph.

    This agent allows defining and executing complex workflows between multiple
    agents using LangGraph for orchestration. ``flow_json`` describes the graph
    as ``nodes`` (types ``start-node``, ``agent-node`` and ``condition-node``)
    and ``edges`` (``source`` / ``target`` / ``sourceHandle``). After the
    workflow finishes, every configured sub-agent is run in order.
    """

    # Field declarations for Pydantic
    flow_json: Dict[str, Any]
    timeout: int
    db: Session

    # Class-level constants. They are ClassVar so Pydantic does not treat them
    # as model fields.
    # Number of agent-node executions after which the flow is forced to end.
    MAX_CYCLES: ClassVar[int] = 10
    # Passed to LangGraph as ``recursion_limit`` for each workflow run.
    RECURSION_LIMIT: ClassVar[int] = 20
    # Prefix of the summary event every condition node appends to ``content``.
    CONDITION_SUMMARY_PREFIX: ClassVar[str] = "Condition evaluated: "

    def __init__(
        self,
        name: str,
        flow_json: Dict[str, Any],
        timeout: int = 300,
        sub_agents: Optional[List[BaseAgent]] = None,
        db: Optional[Session] = None,
        **kwargs,
    ):
        """
        Initializes the workflow agent.

        Args:
            name: Agent name.
            flow_json: Workflow definition in JSON format (``nodes``/``edges``).
            timeout: Maximum execution time of the workflow graph, in seconds.
                A value <= 0 disables the limit.
            sub_agents: Agents executed, in order, after the workflow finishes.
                Defaults to no sub-agents.
            db: Database session passed to ``get_agent`` and ``AgentBuilder``
                when agent nodes run.
            **kwargs: Forwarded unchanged to ``BaseAgent``.
        """
        super().__init__(
            name=name,
            flow_json=flow_json,
            timeout=timeout,
            # Fresh list per instance; a literal ``[]`` default would be shared.
            sub_agents=[] if sub_agents is None else sub_agents,
            db=db,
            **kwargs,
        )
        print(
            f"Workflow agent initialized with {len(flow_json.get('nodes', []))} nodes"
        )

    async def _create_node_functions(self, ctx: InvocationContext):
        """Build the implementations of every supported node type.

        Args:
            ctx: Invocation context; agent nodes run their agents within it.

        Returns:
            A dict mapping node type (``"start-node"``, ``"agent-node"``,
            ``"condition-node"``) to an async generator function called as
            ``fn(state, node_id, node_data)`` that yields exactly one state
            update.
        """

        async def start_node_function(
            state: State,
            node_id: str,
            node_data: Dict[str, Any],
        ) -> AsyncGenerator[State, None]:
            """Entry node: validate that there is content and stamp the start time."""
            print("\n🏁 INITIAL NODE")
            content = state.get("content", [])
            session_id = state.get("session_id", "")
            if not content:
                yield {
                    "content": [
                        Event(
                            author="agent",
                            content=Content(parts=[Part(text="Content not found")]),
                        )
                    ],
                    "status": "error",
                    "node_outputs": {},
                    "cycle_count": 0,
                    "session_id": session_id,
                    "conversation_history": ctx.session.events,
                }
                return

            # Copy rather than mutate the dict held by the current state.
            node_outputs = dict(state.get("node_outputs", {}))
            node_outputs[node_id] = {"started_at": datetime.now().isoformat()}
            yield {
                "content": content,
                "status": "started",
                "node_outputs": node_outputs,
                "cycle_count": 0,
                "session_id": session_id,
                "conversation_history": ctx.session.events,
            }

        async def agent_node_function(
            state: State, node_id: str, node_data: Dict[str, Any]
        ) -> AsyncGenerator[State, None]:
            """Run the agent configured in ``node_data["agent"]`` and append its events."""
            agent_config = node_data.get("agent", {})
            agent_name = agent_config.get("name", "")
            agent_id = agent_config.get("id", "")

            # Every agent execution counts as one cycle of the flow.
            cycle_count = state.get("cycle_count", 0) + 1
            print(f"\n👤 AGENT: {agent_name} (Cycle {cycle_count})")

            content = state.get("content", [])
            session_id = state.get("session_id", "")
            # NOTE: _run_async_impl and the start node both put
            # ctx.session.events itself into this key, so the appends below
            # also extend the session's event list — presumably so later
            # agents see earlier agents' output. Kept as-is; confirm before
            # replacing with a copy.
            conversation_history = state.get("conversation_history", [])
            node_outputs = dict(state.get("node_outputs", {}))

            agent = get_agent(self.db, agent_id)
            if not agent:
                # Record the failure without discarding earlier nodes' results.
                node_outputs[node_id] = {
                    "processed_by": agent_name,
                    "error": "Agent not found",
                    "cycle": cycle_count,
                }
                yield {
                    "content": content
                    + [
                        Event(
                            author="agent",
                            content=Content(parts=[Part(text="Agent not found")]),
                        )
                    ],
                    "session_id": session_id,
                    "status": "error",
                    "node_outputs": node_outputs,
                    "cycle_count": cycle_count,
                    "conversation_history": conversation_history,
                }
                return

            # Import moved to inside the function to avoid circular import
            from src.services.agent_builder import AgentBuilder

            root_agent, exit_stack = await AgentBuilder(self.db).build_agent(agent)
            new_content = []
            try:
                async for event in root_agent.run_async(ctx):
                    conversation_history.append(event)
                    new_content.append(event)
            finally:
                # Release the agent's resources even if the run fails.
                if exit_stack:
                    await exit_stack.aclose()
            print(f"New content: {str(new_content)}")

            node_outputs[node_id] = {
                "processed_by": agent_name,
                "agent_content": new_content,
                "cycle": cycle_count,
            }
            yield {
                "content": content + new_content,
                "status": "processed_by_agent",
                "node_outputs": node_outputs,
                "cycle_count": cycle_count,
                "conversation_history": conversation_history,
                "session_id": session_id,
            }

        async def condition_node_function(
            state: State, node_id: str, node_data: Dict[str, Any]
        ) -> AsyncGenerator[State, None]:
            """Evaluate the node's conditions and record which ones hold.

            The ids of the conditions that hold are stored in
            ``node_outputs[node_id]["conditions_met"]``; the router reads
            them from there to choose the next node.
            """
            label = node_data.get("label", "No name condition")
            conditions = node_data.get("conditions", [])
            cycle_count = state.get("cycle_count", 0)
            print(f"\n🔄 CONDITION: {label} (Cycle {cycle_count})")

            content = state.get("content", [])
            print(f"Evaluating condition for content: '{content}'")
            session_id = state.get("session_id", "")
            conversation_history = state.get("conversation_history", [])

            # Check if the cycle reached the limit (extra security)
            if cycle_count >= self.MAX_CYCLES:
                print(
                    f"⚠️ ATTENTION: Cycle limit reached ({cycle_count}). Forcing termination."
                )
                yield {
                    "content": content
                    + [
                        Event(
                            author="agent",
                            content=Content(parts=[Part(text="Cycle limit reached")]),
                        )
                    ],
                    "status": "cycle_limit_reached",
                    "node_outputs": state.get("node_outputs", {}),
                    "cycle_count": cycle_count,
                    "conversation_history": conversation_history,
                    "session_id": session_id,
                }
                return

            conditions_met = []
            condition_details = []
            for condition in conditions:
                condition_id = condition.get("id")
                condition_data = condition.get("data", {})
                field = condition_data.get("field")
                operator = condition_data.get("operator")
                expected_value = condition_data.get("value")
                current_value = state.get(field, "")
                print(
                    f" Checking if {field} {operator} '{expected_value}' (current value: '{current_value}')"
                )
                condition_details.append(f"{field} {operator} '{expected_value}'")
                if self._evaluate_condition(condition, state):
                    conditions_met.append(condition_id)
                    print(f" ✅ Condition {condition_id} met!")

            node_outputs = dict(state.get("node_outputs", {}))
            node_outputs[node_id] = {
                "condition_evaluated": label,
                "content_evaluated": content,
                "conditions_met": conditions_met,
                "condition_details": condition_details,
                "cycle": cycle_count,
            }

            conditions_result_text = "\n".join(condition_details)
            condition_summary = "TRUE" if conditions_met else "FALSE"
            summary_event = Event(
                author="agent",
                content=Content(
                    parts=[
                        Part(
                            text=f"{self.CONDITION_SUMMARY_PREFIX}{label}\nResult: {condition_summary}\nDetails:\n{conditions_result_text}"
                        )
                    ]
                ),
            )
            yield {
                "content": content + [summary_event],
                "status": "condition_evaluated",
                "node_outputs": node_outputs,
                "cycle_count": cycle_count,
                "conversation_history": conversation_history,
                "session_id": session_id,
            }

        return {
            "start-node": start_node_function,
            "agent-node": agent_node_function,
            "condition-node": condition_node_function,
        }

    def _collect_event_texts(self, events: List[Any]) -> List[str]:
        """Return the non-empty text parts of ``events``, in order.

        Summary events written by condition nodes (author ``"agent"`` and text
        starting with ``CONDITION_SUMMARY_PREFIX``) are skipped. Their details
        repeat each condition's expected value, so including them would make
        e.g. ``contains`` checks on ``content`` match the workflow's own
        bookkeeping. Events without content or parts are ignored.
        """
        texts: List[str] = []
        for event in events:
            parts = getattr(getattr(event, "content", None), "parts", None) or []
            event_texts = [part.text for part in parts if getattr(part, "text", None)]
            if (
                getattr(event, "author", None) == "agent"
                and event_texts
                and event_texts[0].startswith(self.CONDITION_SUMMARY_PREFIX)
            ):
                continue
            texts.extend(event_texts)
        return texts

    def _evaluate_condition(self, condition: Dict[str, Any], state: State) -> bool:
        """Evaluate one flow condition against the current workflow state.

        Only conditions of type ``"previous-output"`` are supported; any other
        type evaluates to False. ``condition["data"]`` holds ``field`` (a state
        key), ``operator`` and ``value``. When ``field`` is ``"content"``, the
        event list is reduced to the space-joined text of its parts (see
        ``_collect_event_texts``). Comparisons other than ``equals`` and
        ``not_equals`` are case-insensitive. Missing numbers in numeric
        comparisons count as 0.

        Args:
            condition: Condition definition taken from the flow JSON.
            state: Current workflow state.

        Returns:
            True if the condition holds. Returns False for unknown types and
            operators, for non-numeric values in numeric comparisons, and for
            ``matches`` with an invalid pattern. ``not_matches`` with an
            invalid pattern returns True.
        """
        if condition.get("type") != "previous-output":
            return False

        condition_data = condition.get("data", {})
        field = condition_data.get("field")
        operator = condition_data.get("operator")
        expected_value = condition_data.get("value")
        actual_value = state.get(field, "")

        if field == "content" and isinstance(actual_value, list):
            # Compare against the events' text, never their repr.
            actual_value = " ".join(self._collect_event_texts(actual_value))
            print(f" Extracted text from events: '{actual_value[:100]}...'")

        actual_str = "" if actual_value is None else str(actual_value)
        expected_str = "" if expected_value is None else str(expected_value)
        actual_lower = actual_str.lower()
        expected_lower = expected_str.lower()

        if operator == "is_defined":
            result = actual_value is not None and actual_value != ""
        elif operator == "is_not_defined":
            result = actual_value is None or actual_value == ""
        elif operator == "equals":
            result = actual_str == expected_str
        elif operator == "not_equals":
            result = actual_str != expected_str
        elif operator == "contains":
            print(
                f" Comparison 'contains' without case distinction: '{expected_lower}' in '{actual_lower[:100]}...'"
            )
            result = expected_lower in actual_lower
        elif operator == "not_contains":
            print(
                f" Comparison 'not_contains' without case distinction: '{expected_lower}' in '{actual_lower[:100]}...'"
            )
            result = expected_lower not in actual_lower
        elif operator == "starts_with":
            result = actual_lower.startswith(expected_lower)
        elif operator == "ends_with":
            result = actual_lower.endswith(expected_lower)
        elif operator in (
            "greater_than",
            "greater_than_or_equal",
            "less_than",
            "less_than_or_equal",
        ):
            try:
                actual_num = float(actual_str) if actual_str else 0.0
                expected_num = float(expected_str) if expected_str else 0.0
            except (ValueError, TypeError):
                print(
                    f" Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'"
                )
                return False
            if operator == "greater_than":
                result = actual_num > expected_num
            elif operator == "greater_than_or_equal":
                result = actual_num >= expected_num
            elif operator == "less_than":
                result = actual_num < expected_num
            else:
                result = actual_num <= expected_num
            print(f" Numeric check '{operator}': {result}")
            return result
        elif operator in ("matches", "not_matches"):
            try:
                pattern = re.compile(expected_str, re.IGNORECASE)
            except re.error:
                print(f" Error in regular expression: '{expected_str}'")
                # An invalid pattern is treated as "no match".
                return operator == "not_matches"
            found = pattern.search(actual_str) is not None
            result = found if operator == "matches" else not found
        else:
            print(f" Unknown operator '{operator}'")
            return False

        print(f" Check '{operator}': {result}")
        return result

    def _create_flow_router(self, flow_data: Dict[str, Any]):
        """Build a factory of LangGraph routing functions from ``flow_data``.

        Routing rules, applied after each node runs:

        * If ``cycle_count`` has reached ``MAX_CYCLES``, the flow ends.
        * Condition nodes follow the edge whose ``sourceHandle`` equals the
          first condition (in declared order) that held. If none held, or no
          held condition has an edge, they use ``"bottom-handle"``. Otherwise
          the flow ends.
        * Other nodes follow the ``"default"`` handle, then ``"bottom-handle"``,
          then any outgoing edge. Otherwise the flow ends.

        Args:
            flow_data: Workflow definition with ``nodes`` and ``edges``.

        Returns:
            ``create_router_for_node(node_id)``, which returns
            ``router(state) -> str`` giving the next node id or ``END``.
        """
        # source node id -> {source handle -> target node id}
        edges_map: Dict[Any, Dict[Any, Any]] = {}
        for edge in flow_data.get("edges", []):
            source_handle = edge.get("sourceHandle", "default")
            edges_map.setdefault(edge.get("source"), {})[source_handle] = edge.get(
                "target"
            )

        # condition node id -> its condition definitions, in declared order
        condition_nodes = {
            node.get("id"): node.get("data", {}).get("conditions", [])
            for node in flow_data.get("nodes", [])
            if node.get("type") == "condition-node"
        }

        def conditions_met_for(node_id: str, conditions: List[Any], state: State):
            """Ids of the conditions that held when the node itself ran."""
            # Use the node's recorded outcome. Re-evaluating here would see
            # the state the node just modified (its summary event and status).
            recorded = (state.get("node_outputs") or {}).get(node_id)
            if isinstance(recorded, dict) and "conditions_met" in recorded:
                return recorded["conditions_met"]
            return [
                condition.get("id")
                for condition in conditions
                if self._evaluate_condition(condition, state)
            ]

        def create_router_for_node(node_id: str):
            outgoing = edges_map.get(node_id, {})

            def router(state: State) -> str:
                print(f"Routing from node: {node_id}")
                cycle_count = state.get("cycle_count", 0)
                if cycle_count >= self.MAX_CYCLES:
                    print(
                        f"⚠️ Cycle limit ({cycle_count}) reached. Finalizing the flow."
                    )
                    return END

                if node_id in condition_nodes:
                    conditions = condition_nodes[node_id]
                    met = conditions_met_for(node_id, conditions, state)
                    for condition in conditions:
                        condition_id = condition.get("id")
                        if condition_id in met:
                            print(
                                f"Condition {condition_id} met. Moving to the next node."
                            )
                            # The edge for a condition uses its id as the handle.
                            if condition_id in outgoing:
                                return outgoing[condition_id]
                        else:
                            print(
                                f"Condition {condition_id} not met. Continuing evaluation or using default path."
                            )
                    if "bottom-handle" in outgoing:
                        print("No condition met. Using default path (bottom-handle).")
                        return outgoing["bottom-handle"]
                    print("No condition met and no default path. Closing the flow.")
                    return END

                for handle in ("default", "bottom-handle"):
                    if handle in outgoing:
                        return outgoing[handle]
                if outgoing:
                    return next(iter(outgoing.values()))
                print(f"No output connection from node {node_id}. Closing the flow.")
                return END

            return router

        return create_router_for_node

    async def _create_graph(
        self, ctx: InvocationContext, flow_data: Dict[str, Any]
    ) -> Any:
        """Build and compile the LangGraph graph described by ``flow_data``.

        Nodes whose type has no implementation are skipped. The entry point is
        the first ``start-node``, or the first node if there is none.

        Args:
            ctx: Invocation context handed to the node implementations.
            flow_data: Workflow definition with ``nodes`` and ``edges``.

        Returns:
            The compiled graph (the result of ``StateGraph.compile()``).
        """
        nodes = flow_data.get("nodes", [])
        edges = flow_data.get("edges", [])
        graph_builder = StateGraph(State)
        node_functions = await self._create_node_functions(ctx)

        def make_graph_node(node_type: str, node_id: str, node_data: Dict[str, Any]):
            """Adapt a node generator to LangGraph: run it, return its last update."""
            node_generator = node_functions[node_type]

            async def graph_node(state):
                last_update = None
                async for update in node_generator(state, node_id, node_data):
                    last_update = update
                return last_update

            return graph_node

        graph_nodes: Dict[Any, Any] = {}
        for node in nodes:
            node_id = node.get("id")
            node_type = node.get("type")
            if node_type not in node_functions:
                print(f"Skipping node {node_id} with unsupported type {node_type}")
                continue
            graph_nodes[node_id] = make_graph_node(
                node_type, node_id, node.get("data", {})
            )
            print(f"Adding node {node_id} of type {node_type}")
            graph_builder.add_node(node_id, graph_nodes[node_id])

        create_router = self._create_flow_router(flow_data)
        for node in nodes:
            node_id = node.get("id")
            if node_id not in graph_nodes:
                continue
            # Destinations LangGraph may route to from this node.
            edge_destinations = {}
            for edge in edges:
                target = edge.get("target")
                if edge.get("source") == node_id and target in graph_nodes:
                    edge_destinations[target] = target
            edge_destinations[END] = END
            print(f"Adding conditional connections for node {node_id}")
            print(f"Possible destinations: {edge_destinations}")
            graph_builder.add_conditional_edges(
                node_id, create_router(node_id), edge_destinations
            )

        entry_point = next(
            (node.get("id") for node in nodes if node.get("type") == "start-node"),
            None,
        )
        if not entry_point and nodes:
            entry_point = nodes[0].get("id")
        if entry_point:
            print(f"Defining entry point: {entry_point}")
            graph_builder.set_entry_point(entry_point)

        return graph_builder.compile()

    @staticmethod
    def _extract_user_message(ctx: InvocationContext) -> Any:
        """Find the user's message for this run.

        Looks at the first part of the most recent user event in the session.
        If that yields nothing, falls back to ``session.state["user_message"]``
        and then ``session.state["message"]``. Returns None if nothing is found.
        """
        session = ctx.session
        user_message = None
        if session and getattr(session, "events", None):
            for event in reversed(session.events):
                if event.author == "user" and event.content and event.content.parts:
                    user_message = event.content.parts[0].text
                    print("Message found in session events")
                    break
        if not user_message and session and session.state:
            if "user_message" in session.state:
                user_message = session.state["user_message"]
            elif "message" in session.state:
                user_message = session.state["message"]
        return user_message

    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        """
        Implementation of the workflow agent.

        Runs the workflow graph on the latest user message and yields every
        non-user event in the final ``content``, then the events of each
        sub-agent. The graph run is limited to ``self.timeout`` seconds when
        that value is positive. Any exception is reported as a single error
        event authored by this agent instead of being raised.
        """
        try:
            user_message = self._extract_user_message(ctx)

            # Use the session ID as a stable identifier
            session_id = (
                str(ctx.session.id)
                if ctx.session and hasattr(ctx.session, "id")
                else str(uuid.uuid4())
            )

            graph = await self._create_graph(ctx, self.flow_json)

            initial_state = State(
                content=[
                    Event(
                        author="user",
                        content=Content(parts=[Part(text=user_message)]),
                    )
                ],
                status="started",
                session_id=session_id,
                cycle_count=0,
                node_outputs={},
                conversation_history=ctx.session.events,
            )

            print("\n🚀 Starting workflow execution:")
            # user_message may be None when no message was found.
            print(f"Initial content: {str(user_message or '')[:100]}...")

            # The recursion limit guards against infinite loops; the timeout
            # enforces the documented maximum execution time.
            run = graph.ainvoke(initial_state, {"recursion_limit": self.RECURSION_LIMIT})
            if self.timeout and self.timeout > 0:
                result = await asyncio.wait_for(run, timeout=self.timeout)
            else:
                result = await run

            final_content = result.get("content", [])
            print(f"\n✅ FINAL RESULT: {final_content[:100]}...")
            for content in final_content:
                if content.author != "user":
                    yield content

            for sub_agent in self.sub_agents:
                async for event in sub_agent.run_async(ctx):
                    yield event

        except Exception as e:
            # Some exceptions (e.g. timeouts) have an empty message.
            detail = str(e) or type(e).__name__
            error_msg = f"Error executing the workflow agent: {detail}"
            print(error_msg)
            # NOTE(review): genai Content roles are usually "user"/"model";
            # "agent" is kept as in the original — confirm it is accepted.
            yield Event(
                author=self.name,
                content=Content(
                    role="agent",
                    parts=[Part(text=error_msg)],
                ),
            )