chore: Fix formatting

PiperOrigin-RevId: 766407362
This commit is contained in:
Selcuk Gun 2025-06-02 17:27:46 -07:00 committed by Copybara-Service
parent 174afb3975
commit eb2b9ef88f

View File

@ -19,7 +19,10 @@ from typing import Union
import graphviz
from ..agents import BaseAgent, SequentialAgent, LoopAgent, ParallelAgent
from ..agents import BaseAgent
from ..agents import LoopAgent
from ..agents import ParallelAgent
from ..agents import SequentialAgent
from ..agents.llm_agent import LlmAgent
from ..tools.agent_tool import AgentTool
from ..tools.base_tool import BaseTool
@ -35,7 +38,12 @@ else:
retrieval_tool_module_loaded = True
async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs, parent_agent=None):
async def build_graph(
graph: graphviz.Digraph,
agent: BaseAgent,
highlight_pairs,
parent_agent=None,
):
"""
Build a graph of the agent and its sub-agents.
Args:
@ -43,7 +51,7 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
agent: The agent to build the graph for.
highlight_pairs: A list of pairs of nodes to highlight.
parent_agent: The parent agent of the current agent. This is specifically used when building Workflow Agents to directly connect a node to nodes inside a Workflow Agent.
Returns:
None
"""
@ -56,11 +64,11 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
if isinstance(tool_or_agent, BaseAgent):
# Added Workflow Agent checks for different agent types
if isinstance(tool_or_agent, SequentialAgent):
return tool_or_agent.name + f" (Sequential Agent)"
return tool_or_agent.name + f' (Sequential Agent)'
elif isinstance(tool_or_agent, LoopAgent):
return tool_or_agent.name + f" (Loop Agent)"
return tool_or_agent.name + f' (Loop Agent)'
elif isinstance(tool_or_agent, ParallelAgent):
return tool_or_agent.name + f" (Parallel Agent)"
return tool_or_agent.name + f' (Parallel Agent)'
else:
return tool_or_agent.name
elif isinstance(tool_or_agent, BaseTool):
@ -93,7 +101,7 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return 'ellipse'
elif retrieval_tool_module_loaded and isinstance(
tool_or_agent, BaseRetrievalTool
):
@ -109,7 +117,7 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
tool_or_agent,
)
return 'cylinder'
def should_build_agent_cluster(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
if isinstance(tool_or_agent, SequentialAgent):
@ -135,19 +143,24 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
tool_or_agent,
)
return False
def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str):
if isinstance(agent, LoopAgent):
if isinstance(agent, LoopAgent):
# Draw the edge from the parent agent to the first sub-agent
draw_edge(parent_agent.name, agent.sub_agents[0].name)
length = len(agent.sub_agents)
currLength = 0
# Draw the edges between the sub-agents
for sub_agent_int_sequential in agent.sub_agents:
for sub_agent_int_sequential in agent.sub_agents:
build_graph(child, sub_agent_int_sequential, highlight_pairs)
# Draw the edge between the current sub-agent and the next one
# If it's the last sub-agent, draw an edge to the first one to indicating a loop
draw_edge(agent.sub_agents[currLength].name, agent.sub_agents[0 if currLength == length - 1 else currLength+1 ].name)
draw_edge(
agent.sub_agents[currLength].name,
agent.sub_agents[
0 if currLength == length - 1 else currLength + 1
].name,
)
currLength += 1
elif isinstance(agent, SequentialAgent):
# Draw the edge from the parent agent to the first sub-agent
@ -160,10 +173,13 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
build_graph(child, sub_agent_int_sequential, highlight_pairs)
# Draw the edge between the current sub-agent and the next one
# If it's the last sub-agent, don't draw an edge to avoid a loop
draw_edge(agent.sub_agents[currLength].name, agent.sub_agents[currLength+1].name) if currLength != length - 1 else None
draw_edge(
agent.sub_agents[currLength].name,
agent.sub_agents[currLength + 1].name,
) if currLength != length - 1 else None
currLength += 1
elif isinstance(agent, ParallelAgent):
elif isinstance(agent, ParallelAgent):
# Draw the edge from the parent agent to every sub-agent
for sub_agent in agent.sub_agents:
build_graph(child, sub_agent, highlight_pairs)
@ -191,7 +207,9 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
if name in highlight_tuple:
# if in highlight, draw highlight node
if asCluster:
cluster = graphviz.Digraph(name='cluster_' + name) # adding "cluster_" to the name makes the graph render as a cluster subgraph
cluster = graphviz.Digraph(
name='cluster_' + name
) # adding "cluster_" to the name makes the graph render as a cluster subgraph
build_cluster(cluster, agent, name)
graph.subgraph(cluster)
else:
@ -207,19 +225,21 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
return
# if not in highlight, draw non-highlight node
if asCluster:
cluster = graphviz.Digraph(name='cluster_' + name) # adding "cluster_" to the name makes the graph render as a cluster subgraph
cluster = graphviz.Digraph(
name='cluster_' + name
) # adding "cluster_" to the name makes the graph render as a cluster subgraph
build_cluster(cluster, agent, name)
graph.subgraph(cluster)
else:
graph.node(
name,
caption,
shape=shape,
style='rounded',
color=light_gray,
fontcolor=light_gray,
name,
caption,
shape=shape,
style='rounded',
color=light_gray,
fontcolor=light_gray,
)
return
@ -234,17 +254,25 @@ async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs
graph.edge(from_name, to_name, color=light_green, dir='back')
return
# if no need to highlight, color gray
if (should_build_agent_cluster(agent)):
graph.edge(from_name, to_name, color=light_gray, )
if should_build_agent_cluster(agent):
graph.edge(
from_name,
to_name,
color=light_gray,
)
else:
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
draw_node(agent)
for sub_agent in agent.sub_agents:
build_graph(graph, sub_agent, highlight_pairs, agent)
if (not should_build_agent_cluster(sub_agent) and not should_build_agent_cluster(agent)): # This is to avoid making a node for a Workflow Agent
if not should_build_agent_cluster(
sub_agent
) and not should_build_agent_cluster(
agent
): # This is to avoid making a node for a Workflow Agent
draw_edge(agent.name, sub_agent.name)
if isinstance(agent, LlmAgent):
for tool in await agent.canonical_tools():