refactor: refactor and refine LangChainTool

PiperOrigin-RevId: 760726719
This commit is contained in:
Xiang (Sean) Zhou 2025-05-19 12:28:19 -07:00 committed by Copybara-Service
parent ae7d19a4c6
commit 74454170a3

View File

@ -13,10 +13,12 @@
# limitations under the License. # limitations under the License.
from typing import Any from typing import Any
from typing import Callable from typing import Optional
from typing import Union
from google.genai import types from google.genai import types
from pydantic import model_validator from langchain.agents import Tool
from langchain_core.tools import BaseTool
from typing_extensions import override from typing_extensions import override
from . import _automatic_function_calling_util from . import _automatic_function_calling_util
@ -24,63 +26,108 @@ from .function_tool import FunctionTool
class LangchainTool(FunctionTool): class LangchainTool(FunctionTool):
"""Use this class to wrap a langchain tool. """Adapter class that wraps a Langchain tool for use with ADK.
If the original tool name and description are not suitable, you can override This adapter converts Langchain tools into a format compatible with Google's
them in the constructor. generative AI function calling interface. It preserves the tool's name,
description, and functionality while adapting its schema.
The original tool's name and description can be overridden if needed.
Args:
tool: A Langchain tool to wrap (BaseTool or a tool with a .run method)
name: Optional override for the tool's name
description: Optional override for the tool's description
Examples:
```python
from langchain.tools import DuckDuckGoSearchTool
from google.genai.tools import LangchainTool
search_tool = DuckDuckGoSearchTool()
wrapped_tool = LangchainTool(search_tool)
```
""" """
tool: Any _langchain_tool: Union[BaseTool, object]
"""The wrapped langchain tool.""" """The wrapped langchain tool."""
def __init__(self, tool: Any): def __init__(
super().__init__(tool._run) self,
self.tool = tool tool: Union[BaseTool, object],
if tool.name: name: Optional[str] = None,
self.name = tool.name description: Optional[str] = None,
if tool.description: ):
self.description = tool.description # Check if the tool has a 'run' method
if not hasattr(tool, 'run') and not hasattr(tool, '_run'):
raise ValueError("Langchain tool must have a 'run' or '_run' method")
@model_validator(mode='before') # Determine which function to use
@classmethod func = tool._run if hasattr(tool, '_run') else tool.run
def populate_name(cls, data: Any) -> Any: super().__init__(func)
# Override this to not use function's signature name as it's
# mostly "run" or "invoke" for thir-party tools. self._langchain_tool = tool
return data
# Set name: priority is 1) explicitly provided name, 2) tool's name, 3) default
if name is not None:
self.name = name
elif hasattr(tool, 'name') and tool.name:
self.name = tool.name
# else: keep default from FunctionTool
# Set description: similar priority
if description is not None:
self.description = description
elif hasattr(tool, 'description') and tool.description:
self.description = tool.description
# else: keep default from FunctionTool
@override @override
def _get_declaration(self) -> types.FunctionDeclaration: def _get_declaration(self) -> types.FunctionDeclaration:
"""Build the function declaration for the tool.""" """Build the function declaration for the tool.
from langchain.agents import Tool
from langchain_core.tools import BaseTool
Returns:
A FunctionDeclaration object that describes the tool's interface.
Raises:
ValueError: If the tool schema cannot be correctly parsed.
"""
try:
# There are two types of tools: # There are two types of tools:
# 1. BaseTool: the tool is defined in langchain.tools. # 1. BaseTool: the tool is defined in langchain_core.tools.
# 2. Other tools: the tool doesn't inherit any class but follow some # 2. Other tools: the tool doesn't inherit any class but follow some
# conventions, like having a "run" method. # conventions, like having a "run" method.
if isinstance(self.tool, BaseTool): # Handle BaseTool type (preferred Langchain approach)
if isinstance(self._langchain_tool, BaseTool):
tool_wrapper = Tool( tool_wrapper = Tool(
name=self.name, name=self.name,
func=self.func, func=self.func,
description=self.description, description=self.description,
) )
if self.tool.args_schema:
tool_wrapper.args_schema = self.tool.args_schema # Add schema if available
function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain( if (
hasattr(self._langchain_tool, 'args_schema')
and self._langchain_tool.args_schema
):
tool_wrapper.args_schema = self._langchain_tool.args_schema
return _automatic_function_calling_util.build_function_declaration_for_langchain(
False, False,
self.name, self.name,
self.description, self.description,
tool_wrapper.func, tool_wrapper.func,
tool_wrapper.args, getattr(tool_wrapper, 'args', None),
) )
return function_declaration
else:
# Need to provide a way to override the function names and descriptions # Need to provide a way to override the function names and descriptions
# as the original function names are mostly ".run" and the descriptions # as the original function names are mostly ".run" and the descriptions
# may not meet users' needs. # may not meet users' needs
function_declaration = ( return _automatic_function_calling_util.build_function_declaration(
_automatic_function_calling_util.build_function_declaration( func=self._langchain_tool.run,
func=self.tool.run,
) )
)
return function_declaration except Exception as e:
raise ValueError(
f'Failed to build function declaration for Langchain tool: {e}'
) from e