From 74454170a32b65648974b236a9531ed20b29554b Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 19 May 2025 12:28:19 -0700 Subject: [PATCH] refactor: refactor and refine LangChainTool PiperOrigin-RevId: 760726719 --- src/google/adk/tools/langchain_tool.py | 145 ++++++++++++++++--------- 1 file changed, 96 insertions(+), 49 deletions(-) diff --git a/src/google/adk/tools/langchain_tool.py b/src/google/adk/tools/langchain_tool.py index b275926..b36c3f5 100644 --- a/src/google/adk/tools/langchain_tool.py +++ b/src/google/adk/tools/langchain_tool.py @@ -13,10 +13,12 @@ # limitations under the License. from typing import Any -from typing import Callable +from typing import Optional +from typing import Union from google.genai import types -from pydantic import model_validator +from langchain.agents import Tool +from langchain_core.tools import BaseTool from typing_extensions import override from . import _automatic_function_calling_util @@ -24,63 +26,108 @@ from .function_tool import FunctionTool class LangchainTool(FunctionTool): - """Use this class to wrap a langchain tool. + """Adapter class that wraps a Langchain tool for use with ADK. - If the original tool name and description are not suitable, you can override - them in the constructor. + This adapter converts Langchain tools into a format compatible with Google's + generative AI function calling interface. It preserves the tool's name, + description, and functionality while adapting its schema. + + The original tool's name and description can be overridden if needed. + + Args: + tool: A Langchain tool to wrap (BaseTool or a tool with a .run method) + name: Optional override for the tool's name + description: Optional override for the tool's description + + Examples: + ```python + from langchain.tools import DuckDuckGoSearchTool + from google.genai.tools import LangchainTool + + search_tool = DuckDuckGoSearchTool() + wrapped_tool = LangchainTool(search_tool) + ``` """ - tool: Any + _langchain_tool: Union[BaseTool, object] """The wrapped langchain tool.""" - def __init__(self, tool: Any): - super().__init__(tool._run) - self.tool = tool - if tool.name: - self.name = tool.name - if tool.description: - self.description = tool.description + def __init__( + self, + tool: Union[BaseTool, object], + name: Optional[str] = None, + description: Optional[str] = None, + ): + # Check if the tool has a 'run' method + if not hasattr(tool, 'run') and not hasattr(tool, '_run'): + raise ValueError("Langchain tool must have a 'run' or '_run' method") - @model_validator(mode='before') - @classmethod - def populate_name(cls, data: Any) -> Any: - # Override this to not use function's signature name as it's - # mostly "run" or "invoke" for thir-party tools. - return data + # Determine which function to use + func = tool._run if hasattr(tool, '_run') else tool.run + super().__init__(func) + + self._langchain_tool = tool + + # Set name: priority is 1) explicitly provided name, 2) tool's name, 3) default + if name is not None: + self.name = name + elif hasattr(tool, 'name') and tool.name: + self.name = tool.name + # else: keep default from FunctionTool + + # Set description: similar priority + if description is not None: + self.description = description + elif hasattr(tool, 'description') and tool.description: + self.description = tool.description + # else: keep default from FunctionTool @override def _get_declaration(self) -> types.FunctionDeclaration: - """Build the function declaration for the tool.""" - from langchain.agents import Tool - from langchain_core.tools import BaseTool + """Build the function declaration for the tool. + + Returns: + A FunctionDeclaration object that describes the tool's interface. + + Raises: + ValueError: If the tool schema cannot be correctly parsed. + """ + try: + # There are two types of tools: + # 1. BaseTool: the tool is defined in langchain_core.tools. + # 2. Other tools: the tool doesn't inherit any class but follow some + # conventions, like having a "run" method. + # Handle BaseTool type (preferred Langchain approach) + if isinstance(self._langchain_tool, BaseTool): + tool_wrapper = Tool( + name=self.name, + func=self.func, + description=self.description, + ) + + # Add schema if available + if ( + hasattr(self._langchain_tool, 'args_schema') + and self._langchain_tool.args_schema + ): + tool_wrapper.args_schema = self._langchain_tool.args_schema + + return _automatic_function_calling_util.build_function_declaration_for_langchain( + False, + self.name, + self.description, + tool_wrapper.func, + getattr(tool_wrapper, 'args', None), + ) - # There are two types of tools: - # 1. BaseTool: the tool is defined in langchain.tools. - # 2. Other tools: the tool doesn't inherit any class but follow some - # conventions, like having a "run" method. - if isinstance(self.tool, BaseTool): - tool_wrapper = Tool( - name=self.name, - func=self.func, - description=self.description, - ) - if self.tool.args_schema: - tool_wrapper.args_schema = self.tool.args_schema - function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain( - False, - self.name, - self.description, - tool_wrapper.func, - tool_wrapper.args, - ) - return function_declaration - else: # Need to provide a way to override the function names and descriptions # as the original function names are mostly ".run" and the descriptions - # may not meet users' needs. - function_declaration = ( - _automatic_function_calling_util.build_function_declaration( - func=self.tool.run, - ) + # may not meet users' needs + return _automatic_function_calling_util.build_function_declaration( + func=self._langchain_tool.run, ) - return function_declaration + + except Exception as e: + raise ValueError( + f'Failed to build function declaration for Langchain tool: {e}' + ) from e