Changes for 0.1.0 release

This commit is contained in:
hangfei
2025-04-09 04:24:34 +00:00
parent 9827820143
commit 363e10619a
25 changed files with 553 additions and 99 deletions
+31 -5
View File
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Union
import graphviz
@@ -21,7 +24,15 @@ from ..agents.llm_agent import LlmAgent
from ..tools.agent_tool import AgentTool
from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool
from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool
logger = logging.getLogger(__name__)
try:
from ..tools.retrieval.base_retrieval_tool import BaseRetrievalTool
except ModuleNotFoundError:
retrieval_tool_module_loaded = False
else:
retrieval_tool_module_loaded = True
def build_graph(graph, agent: BaseAgent, highlight_pairs):
@@ -38,9 +49,12 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs):
raise ValueError(f'Unsupported tool type: {tool_or_agent}')
def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return '🤖 ' + tool_or_agent.name
elif isinstance(tool_or_agent, BaseRetrievalTool):
elif retrieval_tool_module_loaded and isinstance(
tool_or_agent, BaseRetrievalTool
):
return '🔎 ' + tool_or_agent.name
elif isinstance(tool_or_agent, FunctionTool):
return '🔧 ' + tool_or_agent.name
@@ -49,19 +63,31 @@ def build_graph(graph, agent: BaseAgent, highlight_pairs):
elif isinstance(tool_or_agent, BaseTool):
return '🔧 ' + tool_or_agent.name
else:
raise ValueError(f'Unsupported tool type: {type(tool)}')
logger.warning(
'Unsupported tool, type: %s, obj: %s',
type(tool_or_agent),
tool_or_agent,
)
return f'❓ Unsupported tool type: {type(tool_or_agent)}'
def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
if isinstance(tool_or_agent, BaseAgent):
return 'ellipse'
elif isinstance(tool_or_agent, BaseRetrievalTool):
elif retrieval_tool_module_loaded and isinstance(
tool_or_agent, BaseRetrievalTool
):
return 'cylinder'
elif isinstance(tool_or_agent, FunctionTool):
return 'box'
elif isinstance(tool_or_agent, BaseTool):
return 'box'
else:
raise ValueError(f'Unsupported tool type: {type(tool_or_agent)}')
logger.warning(
'Unsupported tool, type: %s, obj: %s',
type(tool_or_agent),
tool_or_agent,
)
return 'cylinder'
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
name = get_node_name(tool_or_agent)