from datetime import datetime
import re
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict

from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.genai.types import Content, Part
from langgraph.graph import StateGraph, END
from sqlalchemy.orm import Session

from src.services.agent_service import get_agent

class State(TypedDict):
    content: List[Event]
    status: str
    session_id: str
    # Additional fields to store any node outputs
    node_outputs: Dict[str, Any]
    # Cycle counter to prevent infinite loops
    cycle_count: int
    conversation_history: List[Event]
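
# Illustrative shape of the state threaded through the graph (assumed values,
# mirroring the initial state built in _run_async_impl below):
#
#   {
#       "content": [Event(author="user", content=Content(parts=[Part(text="hi")]))],
#       "status": "started",
#       "session_id": "abc-123",
#       "node_outputs": {},
#       "cycle_count": 0,
#       "conversation_history": [],
#   }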

class WorkflowAgent(BaseAgent):
    """
    Agent that implements workflows using LangGraph.

    This agent allows defining and executing complex workflows between
    multiple agents, using LangGraph for orchestration.
    """

    # Field declarations for Pydantic
    flow_json: Dict[str, Any]
    timeout: int
    db: Session

    def __init__(
        self,
        name: str,
        flow_json: Dict[str, Any],
        timeout: int = 300,
        sub_agents: Optional[List[BaseAgent]] = None,
        db: Optional[Session] = None,
        **kwargs,
    ):
        """
        Initializes the workflow agent.

        Args:
            name: Agent name
            flow_json: Workflow definition in JSON format
            timeout: Maximum execution time in seconds
            sub_agents: Sub-agents to run after the workflow agent finishes
            db: Database session
        """
        # Initialize the base class (avoiding a shared mutable default
        # for sub_agents)
        super().__init__(
            name=name,
            flow_json=flow_json,
            timeout=timeout,
            sub_agents=sub_agents or [],
            db=db,
            **kwargs,
        )

        print(
            f"Workflow agent initialized with {len(flow_json.get('nodes', []))} nodes"
        )
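
    # Illustrative flow_json shape, inferred from how this class reads it (the
    # node ids, agent id, and condition values below are made up for the example;
    # the authoritative schema lives wherever the flows are authored):
    #
    #   {
    #       "nodes": [
    #           {"id": "start-1", "type": "start-node", "data": {}},
    #           {"id": "agent-1", "type": "agent-node",
    #            "data": {"agent": {"id": "db-agent-uuid", "name": "writer"}}},
    #           {"id": "cond-1", "type": "condition-node",
    #            "data": {"label": "approved?", "conditions": [
    #                {"id": "c-yes", "type": "previous-output",
    #                 "data": {"field": "content", "operator": "contains",
    #                          "value": "approved"}}]}},
    #       ],
    #       "edges": [
    #           {"source": "start-1", "target": "agent-1"},
    #           {"source": "agent-1", "target": "cond-1"},
    #           {"source": "cond-1", "sourceHandle": "c-yes", "target": "agent-1"},
    #       ],
    #   }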

    async def _create_node_functions(self, ctx: InvocationContext):
        """Creates the functions for each type of node in the flow."""

        # Function for the initial node
        async def start_node_function(
            state: State,
            node_id: str,
            node_data: Dict[str, Any],
        ) -> AsyncGenerator[State, None]:
            print("\n🏁 INITIAL NODE")

            content = state.get("content", [])
            session_id = state.get("session_id", "")

            if not content:
                content = [
                    Event(
                        author="agent",
                        content=Content(parts=[Part(text="Content not found")]),
                    )
                ]
                yield {
                    "content": content,
                    "status": "error",
                    "session_id": session_id,
                    "node_outputs": {},
                    "cycle_count": 0,
                    "conversation_history": ctx.session.events,
                }
                return

            # Store specific results for this node
            node_outputs = state.get("node_outputs", {})
            node_outputs[node_id] = {"started_at": datetime.now().isoformat()}

            yield {
                "content": content,
                "status": "started",
                "node_outputs": node_outputs,
                "cycle_count": 0,
                "session_id": session_id,
                "conversation_history": ctx.session.events,
            }

        # Generic function for agent nodes
        async def agent_node_function(
            state: State, node_id: str, node_data: Dict[str, Any]
        ) -> AsyncGenerator[State, None]:
            agent_config = node_data.get("agent", {})
            agent_name = agent_config.get("name", "")
            agent_id = agent_config.get("id", "")

            # Increment the cycle counter
            cycle_count = state.get("cycle_count", 0) + 1
            print(f"\n👤 AGENT: {agent_name} (Cycle {cycle_count})")

            content = state.get("content", [])
            session_id = state.get("session_id", "")

            # Get the conversation history
            conversation_history = state.get("conversation_history", [])

            agent = get_agent(self.db, agent_id)

            if not agent:
                yield {
                    "content": [
                        Event(
                            author="agent",
                            content=Content(parts=[Part(text="Agent not found")]),
                        )
                    ],
                    "session_id": session_id,
                    "status": "error",
                    "node_outputs": {},
                    "cycle_count": cycle_count,
                    "conversation_history": conversation_history,
                }
                return

            # Import inside the function to avoid a circular import
            from src.services.agent_builder import AgentBuilder

            agent_builder = AgentBuilder(self.db)
            root_agent, exit_stack = await agent_builder.build_agent(agent)

            new_content = []
            async for event in root_agent.run_async(ctx):
                conversation_history.append(event)
                new_content.append(event)

            print(f"New content: {new_content}")

            node_outputs = state.get("node_outputs", {})
            node_outputs[node_id] = {
                "processed_by": agent_name,
                "agent_content": new_content,
                "cycle": cycle_count,
            }

            content = content + new_content

            yield {
                "content": content,
                "status": "processed_by_agent",
                "node_outputs": node_outputs,
                "cycle_count": cycle_count,
                "conversation_history": conversation_history,
                "session_id": session_id,
            }

            if exit_stack:
                await exit_stack.aclose()

        # Function for condition nodes
        async def condition_node_function(
            state: State, node_id: str, node_data: Dict[str, Any]
        ) -> AsyncGenerator[State, None]:
            label = node_data.get("label", "Unnamed condition")
            conditions = node_data.get("conditions", [])
            cycle_count = state.get("cycle_count", 0)

            print(f"\n🔄 CONDITION: {label} (Cycle {cycle_count})")

            content = state.get("content", [])
            print(f"Evaluating condition for content: '{content}'")

            session_id = state.get("session_id", "")
            conversation_history = state.get("conversation_history", [])

            # Check all conditions
            conditions_met = []
            condition_details = []
            for condition in conditions:
                condition_id = condition.get("id")
                condition_data = condition.get("data", {})
                field = condition_data.get("field")
                operator = condition_data.get("operator")
                expected_value = condition_data.get("value")

                print(
                    f"  Checking if {field} {operator} '{expected_value}' "
                    f"(current value: '{state.get(field, '')}')"
                )
                if self._evaluate_condition(condition, state):
                    conditions_met.append(condition_id)
                    condition_details.append(
                        f"{field} {operator} '{expected_value}' ✅"
                    )
                    print(f"  ✅ Condition {condition_id} met!")
                else:
                    condition_details.append(
                        f"{field} {operator} '{expected_value}' ❌"
                    )

            # Check whether the cycle limit has been reached (extra safety)
            if cycle_count >= 10:
                print(
                    f"⚠️ ATTENTION: Cycle limit reached ({cycle_count}). "
                    "Forcing termination."
                )

                condition_content = [
                    Event(
                        author="agent",
                        content=Content(parts=[Part(text="Cycle limit reached")]),
                    )
                ]
                content = content + condition_content
                yield {
                    "content": content,
                    "status": "cycle_limit_reached",
                    "node_outputs": state.get("node_outputs", {}),
                    "cycle_count": cycle_count,
                    "conversation_history": conversation_history,
                    "session_id": session_id,
                }
                return

            # Store specific results for this node
            node_outputs = state.get("node_outputs", {})
            node_outputs[node_id] = {
                "condition_evaluated": label,
                "content_evaluated": content,
                "conditions_met": conditions_met,
                "condition_details": condition_details,
                "cycle": cycle_count,
            }

            # Prepare a more descriptive message about the conditions
            conditions_result_text = "\n".join(condition_details)
            condition_summary = "TRUE" if conditions_met else "FALSE"

            condition_content = [
                Event(
                    author="agent",
                    content=Content(
                        parts=[
                            Part(
                                text=(
                                    f"Condition evaluated: {label}\n"
                                    f"Result: {condition_summary}\n"
                                    f"Details:\n{conditions_result_text}"
                                )
                            )
                        ]
                    ),
                )
            ]
            content = content + condition_content

            yield {
                "content": content,
                "status": "condition_evaluated",
                "node_outputs": node_outputs,
                "cycle_count": cycle_count,
                "conversation_history": conversation_history,
                "session_id": session_id,
            }

        return {
            "start-node": start_node_function,
            "agent-node": agent_node_function,
            "condition-node": condition_node_function,
        }

    def _evaluate_condition(self, condition: Dict[str, Any], state: State) -> bool:
        """Evaluates a condition against the current state."""
        condition_type = condition.get("type")
        condition_data = condition.get("data", {})

        if condition_type == "previous-output":
            field = condition_data.get("field")
            operator = condition_data.get("operator")
            expected_value = condition_data.get("value")

            actual_value = state.get(field, "")

            # Special handling for when the content is a list of Events
            if field == "content" and isinstance(actual_value, list) and actual_value:
                # Extract the text from each event for the comparison
                extracted_texts = []
                for event in actual_value:
                    if hasattr(event, "content") and hasattr(event.content, "parts"):
                        for part in event.content.parts:
                            if hasattr(part, "text") and part.text:
                                extracted_texts.append(part.text)

                if extracted_texts:
                    actual_value = " ".join(extracted_texts)
                    print(f"  Extracted text from events: '{actual_value[:100]}...'")

            # Convert both values to strings for easier comparisons
            actual_str = str(actual_value) if actual_value is not None else ""
            expected_str = str(expected_value) if expected_value is not None else ""

            # Definedness checks
            if operator == "is_defined":
                result = actual_value is not None and actual_value != ""
                print(f"  Check '{operator}': {result}")
                return result
            elif operator == "is_not_defined":
                result = actual_value is None or actual_value == ""
                print(f"  Check '{operator}': {result}")
                return result

            # Equality checks
            elif operator == "equals":
                result = actual_str == expected_str
                print(f"  Check '{operator}': {result}")
                return result
            elif operator == "not_equals":
                result = actual_str != expected_str
                print(f"  Check '{operator}': {result}")
                return result

            # Containment checks
            elif operator == "contains":
                # Lowercase both sides for a case-insensitive comparison
                expected_lower = expected_str.lower()
                actual_lower = actual_str.lower()
                print(
                    f"  Case-insensitive 'contains' comparison: "
                    f"'{expected_lower}' in '{actual_lower[:100]}...'"
                )
                result = expected_lower in actual_lower
                print(f"  Check '{operator}': {result}")
                return result
            elif operator == "not_contains":
                expected_lower = expected_str.lower()
                actual_lower = actual_str.lower()
                print(
                    f"  Case-insensitive 'not_contains' comparison: "
                    f"'{expected_lower}' in '{actual_lower[:100]}...'"
                )
                result = expected_lower not in actual_lower
                print(f"  Check '{operator}': {result}")
                return result

            # Prefix and suffix checks
            elif operator == "starts_with":
                result = actual_str.lower().startswith(expected_str.lower())
                print(f"  Check '{operator}': {result}")
                return result
            elif operator == "ends_with":
                result = actual_str.lower().endswith(expected_str.lower())
                print(f"  Check '{operator}': {result}")
                return result

            # Numeric checks (attempting to convert to numbers)
            elif operator in [
                "greater_than",
                "greater_than_or_equal",
                "less_than",
                "less_than_or_equal",
            ]:
                try:
                    actual_num = float(actual_str) if actual_str else 0
                    expected_num = float(expected_str) if expected_str else 0

                    if operator == "greater_than":
                        result = actual_num > expected_num
                    elif operator == "greater_than_or_equal":
                        result = actual_num >= expected_num
                    elif operator == "less_than":
                        result = actual_num < expected_num
                    elif operator == "less_than_or_equal":
                        result = actual_num <= expected_num
                    print(f"  Numeric check '{operator}': {result}")
                    return result
                except (ValueError, TypeError):
                    # If the values cannot be converted to numbers, return False
                    print(
                        f"  Error converting values for numeric comparison: "
                        f"'{actual_str[:100]}...' and '{expected_str}'"
                    )
                    return False

            # Regular-expression checks
            elif operator == "matches":
                try:
                    pattern = re.compile(expected_str, re.IGNORECASE)
                    result = bool(pattern.search(actual_str))
                    print(f"  Check '{operator}': {result}")
                    return result
                except re.error:
                    print(f"  Error in regular expression: '{expected_str}'")
                    return False
            elif operator == "not_matches":
                try:
                    pattern = re.compile(expected_str, re.IGNORECASE)
                    result = not bool(pattern.search(actual_str))
                    print(f"  Check '{operator}': {result}")
                    return result
                except re.error:
                    # If the regex is invalid, treat it as "no match"
                    print(f"  Error in regular expression: '{expected_str}'")
                    return True

        return False
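
    # Illustrative condition, matching the shapes read above (the id and value
    # are assumptions for the example, not a documented schema):
    #
    #   condition = {
    #       "id": "c-yes",
    #       "type": "previous-output",
    #       "data": {"field": "content", "operator": "contains", "value": "approved"},
    #   }
    #
    # With field == "content", the text of every Event in state["content"] is
    # joined and checked case-insensitively for the substring "approved".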

    def _create_flow_router(self, flow_data: Dict[str, Any]):
        """Creates a router based on the connections in flow.json."""
        # Map the connections to understand how the nodes are linked
        edges_map = {}

        for edge in flow_data.get("edges", []):
            source = edge.get("source")
            target = edge.get("target")
            source_handle = edge.get("sourceHandle", "default")

            if source not in edges_map:
                edges_map[source] = {}

            # Store the destination for each specific handle
            edges_map[source][source_handle] = target

        # Map the condition nodes and their conditions
        condition_nodes = {}
        for node in flow_data.get("nodes", []):
            if node.get("type") == "condition-node":
                node_id = node.get("id")
                conditions = node.get("data", {}).get("conditions", [])
                condition_nodes[node_id] = conditions

        # Routing function for each specific node
        def create_router_for_node(node_id: str):
            def router(state: State) -> str:
                print(f"Routing from node: {node_id}")

                # Check whether the cycle limit has been reached
                cycle_count = state.get("cycle_count", 0)
                if cycle_count >= 10:
                    print(
                        f"⚠️ Cycle limit ({cycle_count}) reached. Finalizing the flow."
                    )
                    return END

                # If it is a condition node, evaluate its conditions
                if node_id in condition_nodes:
                    conditions = condition_nodes[node_id]

                    for condition in conditions:
                        condition_id = condition.get("id")

                        # Check whether the condition is met
                        is_condition_met = self._evaluate_condition(condition, state)

                        if is_condition_met:
                            print(
                                f"Condition {condition_id} met. "
                                "Moving to the next node."
                            )

                            # Find the connection that uses this condition_id
                            # as a handle
                            if (
                                node_id in edges_map
                                and condition_id in edges_map[node_id]
                            ):
                                return edges_map[node_id][condition_id]
                        else:
                            print(
                                f"Condition {condition_id} not met. Continuing "
                                "evaluation or using the default path."
                            )

                    # If no condition is met, use the bottom-handle if available
                    if node_id in edges_map and "bottom-handle" in edges_map[node_id]:
                        print("No condition met. Using default path (bottom-handle).")
                        return edges_map[node_id]["bottom-handle"]
                    else:
                        print("No condition met and no default path. Closing the flow.")
                        return END

                # For regular nodes, simply follow the first available connection
                if node_id in edges_map:
                    # Try the default handle or the bottom-handle first
                    for handle in ["default", "bottom-handle"]:
                        if handle in edges_map[node_id]:
                            return edges_map[node_id][handle]

                    # If no specific handle is found, use the first available one
                    if edges_map[node_id]:
                        first_handle = next(iter(edges_map[node_id]))
                        return edges_map[node_id][first_handle]

                # If there is no outgoing connection, close the flow
                print(f"No output connection from node {node_id}. Closing the flow.")
                return END

            return router

        return create_router_for_node
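
    # Illustrative routing (assumed edge list): given
    #
    #   edges = [
    #       {"source": "start-1", "target": "agent-1"},
    #       {"source": "cond-1", "sourceHandle": "c-yes", "target": "agent-2"},
    #       {"source": "cond-1", "sourceHandle": "bottom-handle", "target": "agent-3"},
    #   ]
    #
    # edges_map becomes
    #
    #   {"start-1": {"default": "agent-1"},
    #    "cond-1": {"c-yes": "agent-2", "bottom-handle": "agent-3"}}
    #
    # so the router for "cond-1" returns "agent-2" when condition "c-yes" is met
    # and falls back to "agent-3" (the bottom-handle) otherwise.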

    async def _create_graph(
        self, ctx: InvocationContext, flow_data: Dict[str, Any]
    ) -> StateGraph:
        """Creates a StateGraph from the flow data."""
        # Extract the nodes from the flow
        nodes = flow_data.get("nodes", [])

        # Initialize the StateGraph
        graph_builder = StateGraph(State)

        # Create the functions for each node type
        node_functions = await self._create_node_functions(ctx)

        # Dictionary to store the specific function for each node
        node_specific_functions = {}

        # Add the nodes to the graph
        for node in nodes:
            node_id = node.get("id")
            node_type = node.get("type")
            node_data = node.get("data", {})

            if node_type in node_functions:
                # Create a specific function for this node, binding its id and
                # data as arguments so the closure captures the current values
                def create_node_function(node_type, node_id, node_data):
                    async def node_function(state):
                        # Consume the asynchronous generator and return
                        # the last result
                        result = None
                        async for item in node_functions[node_type](
                            state, node_id, node_data
                        ):
                            result = item
                        return result

                    return node_function

                # Add the specific function to the dictionary
                node_specific_functions[node_id] = create_node_function(
                    node_type, node_id, node_data
                )

                # Add the node to the graph
                print(f"Adding node {node_id} of type {node_type}")
                graph_builder.add_node(node_id, node_specific_functions[node_id])

        # Create the function that generates node-specific routers
        create_router = self._create_flow_router(flow_data)

        # Add conditional connections for each node
        for node in nodes:
            node_id = node.get("id")

            if node_id in node_specific_functions:
                # Create a dictionary of possible destinations
                edge_destinations = {}

                # Map all possible destinations
                for edge in flow_data.get("edges", []):
                    if edge.get("source") == node_id:
                        target = edge.get("target")
                        if target in node_specific_functions:
                            edge_destinations[target] = target

                # Add END as a possible destination
                edge_destinations[END] = END

                # Create the specific router for this node
                node_router = create_router(node_id)

                # Add the conditional connections
                print(f"Adding conditional connections for node {node_id}")
                print(f"Possible destinations: {edge_destinations}")

                graph_builder.add_conditional_edges(
                    node_id, node_router, edge_destinations
                )

        # Find the initial node (usually the start-node)
        entry_point = None
        for node in nodes:
            if node.get("type") == "start-node":
                entry_point = node.get("id")
                break

        # If there is no start-node, use the first node found
        if not entry_point and nodes:
            entry_point = nodes[0].get("id")

        # Define the entry point
        if entry_point:
            print(f"Defining entry point: {entry_point}")
            graph_builder.set_entry_point(entry_point)

        # Compile the graph
        return graph_builder.compile()
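
    # Note on the LangGraph wiring above (worth re-checking against the
    # installed langgraph version): the third argument to add_conditional_edges
    # maps each value the router may return to a destination node. Because
    # edge_destinations maps every target (and END) to itself, the routers can
    # return node ids directly.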

    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        """
        Implementation of the workflow agent.

        This method follows the custom-agent implementation pattern,
        executing the defined workflow and yielding the results.
        """
        try:
            # 1. Extract the user message from the context
            user_message = None

            # Search for the user message in the session events
            if ctx.session and hasattr(ctx.session, "events") and ctx.session.events:
                for event in reversed(ctx.session.events):
                    if event.author == "user" and event.content and event.content.parts:
                        user_message = event.content.parts[0].text
                        print("Message found in session events")
                        break

            # Check the session state if the message was not found in the events
            if not user_message and ctx.session and ctx.session.state:
                if "user_message" in ctx.session.state:
                    user_message = ctx.session.state["user_message"]
                elif "message" in ctx.session.state:
                    user_message = ctx.session.state["message"]

            # 2. Use the session ID as a stable identifier
            session_id = (
                str(ctx.session.id)
                if ctx.session and hasattr(ctx.session, "id")
                else str(uuid.uuid4())
            )

            # 3. Create the workflow graph from the provided JSON
            graph = await self._create_graph(ctx, self.flow_json)

            # 4. Prepare the initial state (guard against a missing user message)
            initial_state = State(
                content=[
                    Event(
                        author="user",
                        content=Content(parts=[Part(text=user_message or "")]),
                    )
                ],
                status="started",
                session_id=session_id,
                cycle_count=0,
                node_outputs={},
                conversation_history=ctx.session.events,
            )

            # 5. Execute the graph
            print("\n🚀 Starting workflow execution:")
            print(f"Initial content: {(user_message or '')[:100]}...")

            # Execute the graph with a recursion limit to avoid infinite loops
            result = await graph.ainvoke(initial_state, {"recursion_limit": 20})

            # 6. Process and return the result
            final_content = result.get("content", [])
            print(f"\n✅ FINAL RESULT: {str(final_content)[:100]}...")

            for content in final_content:
                if content.author != "user":
                    yield content

            # Execute the sub-agents
            for sub_agent in self.sub_agents:
                async for event in sub_agent.run_async(ctx):
                    yield event

        except Exception as e:
            # Handle any uncaught errors
            error_msg = f"Error executing the workflow agent: {str(e)}"
            print(error_msg)
            yield Event(
                author=self.name,
                content=Content(
                    role="agent",
                    parts=[Part(text=error_msg)],
                ),
            )
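
# Usage sketch (illustrative, not part of this module): the agent is meant to
# be driven like any other ADK BaseAgent, e.g. via google.adk.runners.Runner.
# The names below ("workflows", the session service, message construction) are
# assumptions for the example:
#
#   from google.adk.runners import Runner
#
#   agent = WorkflowAgent(name="my_workflow", flow_json=flow, db=db_session)
#   runner = Runner(agent=agent, app_name="workflows", session_service=sessions)
#   async for event in runner.run_async(
#       user_id="user-1", session_id="session-1", new_message=message
#   ):
#       print(event)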