refactor: refactor and refine LangChainTool

PiperOrigin-RevId: 760726719
This commit is contained in:
Xiang (Sean) Zhou 2025-05-19 12:28:19 -07:00 committed by Copybara-Service
parent ae7d19a4c6
commit 74454170a3

View File

@ -13,10 +13,12 @@
# limitations under the License.
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union
from google.genai import types
from pydantic import model_validator
from langchain.agents import Tool
from langchain_core.tools import BaseTool
from typing_extensions import override
from . import _automatic_function_calling_util
@ -24,63 +26,108 @@ from .function_tool import FunctionTool
class LangchainTool(FunctionTool):
"""Use this class to wrap a langchain tool.
"""Adapter class that wraps a Langchain tool for use with ADK.
If the original tool name and description are not suitable, you can override
them in the constructor.
This adapter converts Langchain tools into a format compatible with Google's
generative AI function calling interface. It preserves the tool's name,
description, and functionality while adapting its schema.
The original tool's name and description can be overridden if needed.
Args:
tool: A Langchain tool to wrap (BaseTool or a tool with a .run method)
name: Optional override for the tool's name
description: Optional override for the tool's description
Examples:
```python
from langchain.tools import DuckDuckGoSearchTool
from google.genai.tools import LangchainTool
search_tool = DuckDuckGoSearchTool()
wrapped_tool = LangchainTool(search_tool)
```
"""
tool: Any
_langchain_tool: Union[BaseTool, object]
"""The wrapped langchain tool."""
def __init__(self, tool: Any):
super().__init__(tool._run)
self.tool = tool
if tool.name:
self.name = tool.name
if tool.description:
self.description = tool.description
def __init__(
self,
tool: Union[BaseTool, object],
name: Optional[str] = None,
description: Optional[str] = None,
):
# Check if the tool has a 'run' method
if not hasattr(tool, 'run') and not hasattr(tool, '_run'):
raise ValueError("Langchain tool must have a 'run' or '_run' method")
@model_validator(mode='before')
@classmethod
def populate_name(cls, data: Any) -> Any:
# Override this to not use function's signature name as it's
# mostly "run" or "invoke" for thir-party tools.
return data
# Determine which function to use
func = tool._run if hasattr(tool, '_run') else tool.run
super().__init__(func)
self._langchain_tool = tool
# Set name: priority is 1) explicitly provided name, 2) tool's name, 3) default
if name is not None:
self.name = name
elif hasattr(tool, 'name') and tool.name:
self.name = tool.name
# else: keep default from FunctionTool
# Set description: similar priority
if description is not None:
self.description = description
elif hasattr(tool, 'description') and tool.description:
self.description = tool.description
# else: keep default from FunctionTool
@override
def _get_declaration(self) -> types.FunctionDeclaration:
"""Build the function declaration for the tool."""
from langchain.agents import Tool
from langchain_core.tools import BaseTool
"""Build the function declaration for the tool.
Returns:
A FunctionDeclaration object that describes the tool's interface.
Raises:
ValueError: If the tool schema cannot be correctly parsed.
"""
try:
# There are two types of tools:
# 1. BaseTool: the tool is defined in langchain_core.tools.
# 2. Other tools: the tool doesn't inherit any class but follow some
# conventions, like having a "run" method.
# Handle BaseTool type (preferred Langchain approach)
if isinstance(self._langchain_tool, BaseTool):
tool_wrapper = Tool(
name=self.name,
func=self.func,
description=self.description,
)
# Add schema if available
if (
hasattr(self._langchain_tool, 'args_schema')
and self._langchain_tool.args_schema
):
tool_wrapper.args_schema = self._langchain_tool.args_schema
return _automatic_function_calling_util.build_function_declaration_for_langchain(
False,
self.name,
self.description,
tool_wrapper.func,
getattr(tool_wrapper, 'args', None),
)
# There are two types of tools:
# 1. BaseTool: the tool is defined in langchain.tools.
# 2. Other tools: the tool doesn't inherit any class but follow some
# conventions, like having a "run" method.
if isinstance(self.tool, BaseTool):
tool_wrapper = Tool(
name=self.name,
func=self.func,
description=self.description,
)
if self.tool.args_schema:
tool_wrapper.args_schema = self.tool.args_schema
function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain(
False,
self.name,
self.description,
tool_wrapper.func,
tool_wrapper.args,
)
return function_declaration
else:
# Need to provide a way to override the function names and descriptions
# as the original function names are mostly ".run" and the descriptions
# may not meet users' needs.
function_declaration = (
_automatic_function_calling_util.build_function_declaration(
func=self.tool.run,
)
# may not meet users' needs
return _automatic_function_calling_util.build_function_declaration(
func=self._langchain_tool.run,
)
return function_declaration
except Exception as e:
raise ValueError(
f'Failed to build function declaration for Langchain tool: {e}'
) from e