diff --git a/src/google/adk/cli/agent_graph.py b/src/google/adk/cli/agent_graph.py index 18e1eae..4a0e53e 100644 --- a/src/google/adk/cli/agent_graph.py +++ b/src/google/adk/cli/agent_graph.py @@ -19,7 +19,7 @@ from typing import Union import graphviz -from ..agents import BaseAgent +from ..agents import BaseAgent, SequentialAgent, LoopAgent, ParallelAgent from ..agents.llm_agent import LlmAgent from ..tools.agent_tool import AgentTool from ..tools.base_tool import BaseTool @@ -35,14 +35,34 @@ else: retrieval_tool_module_loaded = True -async def build_graph(graph, agent: BaseAgent, highlight_pairs): +async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs, parent_agent=None): + """ + Build a graph of the agent and its sub-agents. + Args: + graph: The graph to build on. + agent: The agent to build the graph for. + highlight_pairs: A list of pairs of nodes to highlight. + parent_agent: The parent agent of the current agent. This is specifically used when building Workflow Agents to directly connect a node to nodes inside a Workflow Agent. + + Returns: + None + """ dark_green = '#0F5223' light_green = '#69CB87' light_gray = '#cccccc' + white = '#ffffff' def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]): if isinstance(tool_or_agent, BaseAgent): - return tool_or_agent.name + # Added Workflow Agent checks for different agent types + if isinstance(tool_or_agent, SequentialAgent): + return tool_or_agent.name + f" (Sequential Agent)" + elif isinstance(tool_or_agent, LoopAgent): + return tool_or_agent.name + f" (Loop Agent)" + elif isinstance(tool_or_agent, ParallelAgent): + return tool_or_agent.name + f" (Parallel Agent)" + else: + return tool_or_agent.name elif isinstance(tool_or_agent, BaseTool): return tool_or_agent.name else: @@ -73,6 +93,7 @@ async def build_graph(graph, agent: BaseAgent, highlight_pairs): def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]): if isinstance(tool_or_agent, BaseAgent): return 'ellipse' + elif retrieval_tool_module_loaded and isinstance( tool_or_agent, BaseRetrievalTool ): @@ -88,33 +109,120 @@ async def build_graph(graph, agent: BaseAgent, highlight_pairs): tool_or_agent, ) return 'cylinder' + + def should_build_agent_cluster(tool_or_agent: Union[BaseAgent, BaseTool]): + if isinstance(tool_or_agent, BaseAgent): + if isinstance(tool_or_agent, SequentialAgent): + return True + elif isinstance(tool_or_agent, LoopAgent): + return True + elif isinstance(tool_or_agent, ParallelAgent): + return True + else: + return False + elif retrieval_tool_module_loaded and isinstance( + tool_or_agent, BaseRetrievalTool + ): + return False + elif isinstance(tool_or_agent, FunctionTool): + return False + elif isinstance(tool_or_agent, BaseTool): + return False + else: + logger.warning( + 'Unsupported tool, type: %s, obj: %s', + type(tool_or_agent), + tool_or_agent, + ) + return False + + def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str): + if isinstance(agent, LoopAgent): + # Draw the edge from the parent agent to the first sub-agent + draw_edge(parent_agent.name, agent.sub_agents[0].name) + length = len(agent.sub_agents) + currLength = 0 + # Draw the edges between the sub-agents + for sub_agent_int_sequential in agent.sub_agents: + build_graph(child, sub_agent_int_sequential, highlight_pairs) + # Draw the edge between the current sub-agent and the next one + # If it's the last sub-agent, draw an edge to the first one to indicating a loop + draw_edge(agent.sub_agents[currLength].name, agent.sub_agents[0 if currLength == length - 1 else currLength+1 ].name) + currLength += 1 + elif isinstance(agent, SequentialAgent): + # Draw the edge from the parent agent to the first sub-agent + draw_edge(parent_agent.name, agent.sub_agents[0].name) + length = len(agent.sub_agents) + currLength = 0 + + # Draw the edges between the sub-agents + for sub_agent_int_sequential in agent.sub_agents: + build_graph(child, sub_agent_int_sequential, highlight_pairs) + # Draw the edge between the current sub-agent and the next one + # If it's the last sub-agent, don't draw an edge to avoid a loop + draw_edge(agent.sub_agents[currLength].name, agent.sub_agents[currLength+1].name) if currLength != length - 1 else None + currLength += 1 + + elif isinstance(agent, ParallelAgent): + # Draw the edge from the parent agent to every sub-agent + for sub_agent in agent.sub_agents: + build_graph(child, sub_agent, highlight_pairs) + draw_edge(parent_agent.name, sub_agent.name) + else: + for sub_agent in agent.sub_agents: + build_graph(child, sub_agent, highlight_pairs) + draw_edge(agent.name, sub_agent.name) + + child.attr( + label=name, + style='rounded', + color=white, + fontcolor=light_gray, + ) def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]): name = get_node_name(tool_or_agent) shape = get_node_shape(tool_or_agent) caption = get_node_caption(tool_or_agent) + asCluster = should_build_agent_cluster(tool_or_agent) + child = None if highlight_pairs: for highlight_tuple in highlight_pairs: if name in highlight_tuple: - graph.node( - name, - caption, - style='filled,rounded', - fillcolor=dark_green, - color=dark_green, - shape=shape, - fontcolor=light_gray, - ) + # if in highlight, draw highlight node + if asCluster: + cluster = graphviz.Digraph(name='cluster_' + name) # adding "cluster_" to the name makes the graph render as a cluster subgraph + build_cluster(cluster, agent, name) + graph.subgraph(cluster) + else: + graph.node( + name, + caption, + style='filled,rounded', + fillcolor=dark_green, + color=dark_green, + shape=shape, + fontcolor=light_gray, + ) return - # if not in highlight, draw non-highliht node - graph.node( + # if not in highlight, draw non-highlight node + if asCluster: + + cluster = graphviz.Digraph(name='cluster_' + name) # adding "cluster_" to the name makes the graph render as a cluster subgraph + build_cluster(cluster, agent, name) + graph.subgraph(cluster) + + else: + graph.node( name, caption, shape=shape, style='rounded', color=light_gray, fontcolor=light_gray, - ) + ) + + return def draw_edge(from_name, to_name): if highlight_pairs: @@ -126,12 +234,18 @@ async def build_graph(graph, agent: BaseAgent, highlight_pairs): graph.edge(from_name, to_name, color=light_green, dir='back') return # if no need to highlight, color gray - graph.edge(from_name, to_name, arrowhead='none', color=light_gray) + if (should_build_agent_cluster(agent)): + + graph.edge(from_name, to_name, color=light_gray, ) + else: + graph.edge(from_name, to_name, arrowhead='none', color=light_gray) draw_node(agent) for sub_agent in agent.sub_agents: - await build_graph(graph, sub_agent, highlight_pairs) - draw_edge(agent.name, sub_agent.name) + + build_graph(graph, sub_agent, highlight_pairs, agent) + if (not should_build_agent_cluster(sub_agent) and not should_build_agent_cluster(agent)): # This is to avoid making a node for a Workflow Agent + draw_edge(agent.name, sub_agent.name) if isinstance(agent, LlmAgent): for tool in await agent.canonical_tools(): draw_node(tool)