structure saas with tools
@@ -0,0 +1,327 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes and functions for working with agent engines."""

from typing import Dict, Iterable, Optional, Sequence, Union

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aip_utils
from google.cloud.aiplatform_v1 import types as aip_types

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.agent_engines._agent_engines import (
    AgentEngine,
    Cloneable,
    ModuleAgent,
    OperationRegistrable,
    Queryable,
    StreamQueryable,
)
from vertexai.agent_engines.templates.ag2 import (
    AG2Agent,
)
from vertexai.agent_engines.templates.langchain import (
    LangchainAgent,
)
from vertexai.agent_engines.templates.langgraph import (
    LanggraphAgent,
)


_LOGGER = base.Logger(__name__)


def get(resource_name: str) -> AgentEngine:
    """Retrieves an Agent Engine resource.

    Args:
        resource_name (str):
            Required. A fully-qualified resource name or ID such as
            "projects/123/locations/us-central1/reasoningEngines/456" or
            "456" when project and location are initialized or passed.
    """
    return AgentEngine(resource_name)


def create(
    agent_engine: Optional[Union[Queryable, OperationRegistrable]] = None,
    *,
    requirements: Optional[Union[str, Sequence[str]]] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    gcs_dir_name: Optional[str] = None,
    extra_packages: Optional[Sequence[str]] = None,
    env_vars: Optional[
        Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]]
    ] = None,
) -> AgentEngine:
    """Creates a new Agent Engine.

    The Agent Engine will be an instance of the `agent_engine` that
    was passed in, running remotely on Vertex AI.

    Sample ``src_dir`` contents (e.g. ``./user_src_dir``):

    .. code-block:: python

        user_src_dir/
        |-- main.py
        |-- requirements.txt
        |-- user_code/
        |   |-- utils.py
        |   |-- ...
        |-- ...

    To build an Agent Engine with the above files, run:

    .. code-block:: python

        remote_agent = agent_engines.create(
            agent_engine=local_agent,
            requirements=[
                # I.e. the PyPI dependencies listed in requirements.txt
                "google-cloud-aiplatform==1.25.0",
                "langchain==0.0.242",
                ...
            ],
            extra_packages=[
                "./user_src_dir/main.py",  # a single file
                "./user_src_dir/user_code",  # a directory
                ...
            ],
        )

    Args:
        agent_engine (AgentEngineInterface):
            Required. The Agent Engine to be created.
        requirements (Union[str, Sequence[str]]):
            Optional. The set of PyPI dependencies needed. It can either be
            the path to a single file (requirements.txt), or an ordered list
            of strings corresponding to each line of the requirements file.
        display_name (str):
            Optional. The user-defined name of the Agent Engine.
            The name can be up to 128 characters long and can comprise any
            UTF-8 character.
        description (str):
            Optional. The description of the Agent Engine.
        gcs_dir_name (str):
            Optional. The GCS bucket directory under `staging_bucket` to
            use for staging the artifacts needed.
        extra_packages (Sequence[str]):
            Optional. The set of extra user-provided packages (if any).
        env_vars (Union[Sequence[str], Dict[str, Union[str, SecretRef]]]):
            Optional. The environment variables to be set when running the
            Agent Engine. If it is a list of strings, each string should be
            a valid key to `os.environ`. If it is a dictionary, the keys are
            the environment variable names, and the values are the
            corresponding values.

    Returns:
        AgentEngine: The Agent Engine that was created.

    Raises:
        ValueError: If the `project` was not set using `vertexai.init`.
        ValueError: If the `location` was not set using `vertexai.init`.
        ValueError: If the `staging_bucket` was not set using `vertexai.init`.
        ValueError: If the `staging_bucket` does not start with "gs://".
        FileNotFoundError: If `extra_packages` includes a file or directory
            that does not exist.
        IOError: If `requirements` is a string that corresponds to a
            nonexistent file.
    """
    return AgentEngine.create(
        agent_engine=agent_engine,
        requirements=requirements,
        display_name=display_name,
        description=description,
        gcs_dir_name=gcs_dir_name,
        extra_packages=extra_packages,
        env_vars=env_vars,
    )


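# --- Illustrative usage sketch (not part of the original file) ---
# A minimal end-to-end flow for `create`, assuming the `agent_engines` and
# `langchain` extras are installed. The project ID, bucket, and model name
# below are hypothetical placeholders.

import vertexai
from vertexai import agent_engines

vertexai.init(
    project="my-project",  # hypothetical project ID
    location="us-central1",
    staging_bucket="gs://my-staging-bucket",  # hypothetical bucket
)

local_agent = agent_engines.LangchainAgent(model="gemini-1.5-pro")
remote_agent = agent_engines.create(
    agent_engine=local_agent,
    display_name="My Agent",
    requirements=["google-cloud-aiplatform[agent_engines,langchain]"],
)

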
def list(*, filter: str = "") -> Iterable[AgentEngine]:
|
||||
"""List all instances of Agent Engine matching the filter.
|
||||
|
||||
Example Usage:
|
||||
|
||||
.. code-block:: python
|
||||
import vertexai
|
||||
from vertexai import agent_engines
|
||||
|
||||
vertexai.init(project="my_project", location="us-central1")
|
||||
agent_engines.list(filter='display_name="My Custom Agent"')
|
||||
|
||||
Args:
|
||||
filter (str):
|
||||
Optional. An expression for filtering the results of the request.
|
||||
For field names both snake_case and camelCase are supported.
|
||||
|
||||
Returns:
|
||||
Iterable[AgentEngine]: An iterable of Agent Engines matching the filter.
|
||||
"""
|
||||
api_client = initializer.global_config.create_client(
|
||||
client_class=aip_utils.AgentEngineClientWithOverride,
|
||||
)
|
||||
for agent in api_client.list_reasoning_engines(
|
||||
request=aip_types.ListReasoningEnginesRequest(
|
||||
parent=initializer.global_config.common_location_path(),
|
||||
filter=filter,
|
||||
)
|
||||
):
|
||||
yield AgentEngine(agent.name)
|
||||
|
||||
|
||||
def delete(
    resource_name: str,
    *,
    force: bool = False,
    **kwargs,
) -> None:
    """Deletes an Agent Engine resource.

    Args:
        resource_name (str):
            Required. The name of the Agent Engine to be deleted. Format:
            `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`
        force (bool):
            Optional. If set to True, child resources will also be deleted.
            Otherwise, the request will fail with a FAILED_PRECONDITION error
            when the Agent Engine has undeleted child resources. Defaults to
            False.
        **kwargs (dict[str, Any]):
            Optional. Additional keyword arguments to pass to the
            `delete_reasoning_engine` method.
    """
    api_client = initializer.global_config.create_client(
        client_class=aip_utils.AgentEngineClientWithOverride,
    )
    _LOGGER.info(f"Deleting AgentEngine resource: {resource_name}")
    operation_future = api_client.delete_reasoning_engine(
        request=aip_types.DeleteReasoningEngineRequest(
            name=resource_name,
            force=force,
            **(kwargs or {}),
        )
    )
    _LOGGER.info(f"Delete AgentEngine backing LRO: {operation_future.operation.name}")
    operation_future.result()
    _LOGGER.info(f"AgentEngine resource deleted: {resource_name}")


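# --- Illustrative usage sketch (not part of the original file) ---
# Force-deleting an agent together with its child resources; the resource
# name below is a hypothetical placeholder.

from vertexai import agent_engines

agent_engines.delete(
    "projects/my-project/locations/us-central1/reasoningEngines/123",
    force=True,  # also delete child resources instead of failing
)

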
def update(
    resource_name: str,
    *,
    agent_engine: Optional[Union[Queryable, OperationRegistrable]] = None,
    requirements: Optional[Union[str, Sequence[str]]] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    gcs_dir_name: Optional[str] = None,
    extra_packages: Optional[Sequence[str]] = None,
    env_vars: Optional[
        Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]]
    ] = None,
) -> "AgentEngine":
    """Updates an existing Agent Engine.

    This method updates the configuration of a deployed Agent Engine,
    identified by its resource name. Unlike the `create` function, which
    requires an `agent_engine` object, all arguments in this method are
    optional. This method allows you to modify individual aspects of the
    configuration by providing any of the optional arguments.

    Args:
        resource_name (str):
            Required. The name of the Agent Engine to be updated. Format:
            `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`.
        agent_engine (AgentEngineInterface):
            Optional. The instance to be used as the updated Agent Engine. If
            it is not specified, the existing instance will be used.
        requirements (Union[str, Sequence[str]]):
            Optional. The set of PyPI dependencies needed. It can either be
            the path to a single file (requirements.txt), or an ordered list
            of strings corresponding to each line of the requirements file.
            If it is not specified, the existing requirements will be used.
            If it is set to an empty string or list, the existing
            requirements will be removed.
        display_name (str):
            Optional. The user-defined name of the Agent Engine.
            The name can be up to 128 characters long and can comprise any
            UTF-8 character.
        description (str):
            Optional. The description of the Agent Engine.
        gcs_dir_name (str):
            Optional. The GCS bucket directory under `staging_bucket` to
            use for staging the artifacts needed.
        extra_packages (Sequence[str]):
            Optional. The set of extra user-provided packages (if any). If
            it is not specified, the existing extra packages will be used.
            If it is set to an empty list, the existing extra packages will
            be removed.
        env_vars (Union[Sequence[str], Dict[str, Union[str, SecretRef]]]):
            Optional. The environment variables to be set when running the
            Agent Engine. If it is a list of strings, each string should be
            a valid key to `os.environ`. If it is a dictionary, the keys are
            the environment variable names, and the values are the
            corresponding values.

    Returns:
        AgentEngine: The Agent Engine that was updated.

    Raises:
        ValueError: If the `staging_bucket` was not set using `vertexai.init`.
        ValueError: If the `staging_bucket` does not start with "gs://".
        FileNotFoundError: If `extra_packages` includes a file or directory
            that does not exist.
        ValueError: If none of `display_name`, `description`,
            `requirements`, `extra_packages`, or `agent_engine` were
            specified.
        IOError: If `requirements` is a string that corresponds to a
            nonexistent file.
    """
    agent = get(resource_name)
    return agent.update(
        agent_engine=agent_engine,
        requirements=requirements,
        display_name=display_name,
        description=description,
        gcs_dir_name=gcs_dir_name,
        extra_packages=extra_packages,
        env_vars=env_vars,
    )


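# --- Illustrative usage sketch (not part of the original file) ---
# Updating a single field of a deployed agent; every other aspect of the
# configuration is left untouched. The resource name is a hypothetical
# placeholder.

from vertexai import agent_engines

remote_agent = agent_engines.update(
    "projects/my-project/locations/us-central1/reasoningEngines/123",
    display_name="My Renamed Agent",
)

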
__all__ = (
    # Resources
    "AgentEngine",
    # Protocols
    "Cloneable",
    "OperationRegistrable",
    "Queryable",
    "StreamQueryable",
    # Methods
    "create",
    "delete",
    "get",
    "list",
    "update",
    # Templates
    "ModuleAgent",
    "LangchainAgent",
    "LanggraphAgent",
    "AG2Agent",
)
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
@@ -0,0 +1,769 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import inspect
import json
import sys
import types
import typing
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    TypedDict,
    Union,
)
from importlib import metadata as importlib_metadata

import proto

from google.cloud.aiplatform import base
from google.api import httpbody_pb2
from google.protobuf import struct_pb2
from google.protobuf import json_format

try:
    # For LangChain templates, they might not import langchain_core and get
    # PydanticUserError: `query` is not fully defined; you should define
    # `RunnableConfig`, then call `query.model_rebuild()`.
    import langchain_core.runnables.config

    RunnableConfig = langchain_core.runnables.config.RunnableConfig
except ImportError:
    RunnableConfig = Any

try:
    import packaging

    SpecifierSet = packaging.specifiers.SpecifierSet
except AttributeError:
    SpecifierSet = Any

try:
    _BUILTIN_MODULE_NAMES: Sequence[str] = sys.builtin_module_names
except AttributeError:
    _BUILTIN_MODULE_NAMES: Sequence[str] = []

try:
    # sys.stdlib_module_names is available from Python 3.10 onwards.
    _STDLIB_MODULE_NAMES: frozenset = sys.stdlib_module_names
except AttributeError:
    _STDLIB_MODULE_NAMES: frozenset = frozenset()

try:
    _PACKAGE_DISTRIBUTIONS: Mapping[
        str, Sequence[str]
    ] = importlib_metadata.packages_distributions()
except AttributeError:
    _PACKAGE_DISTRIBUTIONS: Mapping[str, Sequence[str]] = {}

try:
    from autogen.agentchat import chat

    AutogenChatResult = chat.ChatResult
except ImportError:
    AutogenChatResult = Any

try:
    from autogen.io import run_response

    AutogenRunResponse = run_response.RunResponse
except ImportError:
    AutogenRunResponse = Any

JsonDict = Dict[str, Any]


class _RequirementsValidationActions(TypedDict):
    append: Set[str]


class _RequirementsValidationWarnings(TypedDict):
    missing: Set[str]
    incompatible: Set[str]


class _RequirementsValidationResult(TypedDict):
    warnings: _RequirementsValidationWarnings
    actions: _RequirementsValidationActions


LOGGER = base.Logger("vertexai.agent_engines")

_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES))
_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"])
_ACTIONS_KEY = "actions"
_ACTION_APPEND = "append"
_WARNINGS_KEY = "warnings"
_WARNING_MISSING = "missing"
_WARNING_INCOMPATIBLE = "incompatible"


def to_proto(
    obj: Union[JsonDict, proto.Message],
    message: Optional[proto.Message] = None,
) -> proto.Message:
    """Parses a JSON-like object into a message.

    If the object is already a message, this will return the object as-is. If
    the object is a JSON Dict, this will parse and merge the object into the
    message.

    Args:
        obj (Union[dict[str, Any], proto.Message]):
            Required. The object to convert to a proto message.
        message (proto.Message):
            Optional. A protocol buffer message to merge the obj into. It
            defaults to Struct() if unspecified.

    Returns:
        proto.Message: The same message passed as argument.
    """
    if message is None:
        message = struct_pb2.Struct()
    if isinstance(obj, (proto.Message, struct_pb2.Struct)):
        return obj
    try:
        json_format.ParseDict(obj, message._pb)
    except AttributeError:
        json_format.ParseDict(obj, message)
    return message


def to_dict(message: proto.Message) -> JsonDict:
    """Converts the contents of the protobuf message to JSON format.

    Args:
        message (proto.Message):
            Required. The proto message to be converted to a JSON dictionary.

    Returns:
        dict[str, Any]: A dictionary containing the contents of the proto.
    """
    try:
        # Best effort attempt to convert the message into a JSON dictionary.
        result: JsonDict = json.loads(
            json_format.MessageToJson(
                message._pb,
                preserving_proto_field_name=True,
            )
        )
    except AttributeError:
        result: JsonDict = json.loads(
            json_format.MessageToJson(
                message,
                preserving_proto_field_name=True,
            )
        )
    return result


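# --- Illustrative usage sketch (not part of the original file) ---
# Round-tripping a plain dict through a protobuf Struct using the two
# helpers defined above.

from google.protobuf import struct_pb2

payload = {"temperature": 0.2, "tags": ["a", "b"]}
message = to_proto(payload)   # parsed into a struct_pb2.Struct
assert isinstance(message, struct_pb2.Struct)
roundtrip = to_dict(message)  # back to a JSON-like dict
assert roundtrip == payload

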
def _dataclass_to_dict_or_raise(obj: Any) -> JsonDict:
    """Converts a dataclass to a JSON dictionary.

    Args:
        obj (Any):
            Required. The dataclass to be converted to a JSON dictionary.

    Returns:
        dict[str, Any]: A dictionary containing the contents of the dataclass.

    Raises:
        TypeError: If the object is not a dataclass.
    """
    if not dataclasses.is_dataclass(obj):
        raise TypeError(f"Object is not a dataclass: {obj}")
    return json.loads(json.dumps(dataclasses.asdict(obj)))


def _autogen_run_response_protocol_to_dict(
    obj: AutogenRunResponse,
) -> JsonDict:
    """Converts an AutogenRunResponse object into a JSON-serializable dictionary.

    This function takes a `RunResponseProtocol` object and transforms its
    relevant attributes into a dictionary format suitable for JSON conversion.

    The `RunResponseProtocol` defines the structure of the response object,
    which typically includes:

    * **summary** (`Optional[str]`):
        A textual summary of the run.
    * **messages** (`Iterable[Message]`):
        A sequence of messages exchanged during the run.
        Each message is expected to be a JSON-serializable dictionary
        (`Dict[str, Any]`).
    * **events** (`Iterable[BaseEvent]`):
        A sequence of events that occurred during the run.
        Note: The `process()` method, if present, is called before conversion,
        which typically clears this event queue.
    * **context_variables** (`Optional[dict[str, Any]]`):
        A dictionary containing contextual variables from the run.
    * **last_speaker** (`Optional[Agent]`):
        The agent that produced the last message.
        The `Agent` object has attributes like `name` (Optional[str]) and
        `description` (Optional[str]).
    * **cost** (`Optional[Cost]`):
        Information about the computational cost of the run.
        The `Cost` object inherits from `pydantic.BaseModel` and is converted
        to JSON using its `model_dump_json()` method.
    * **process** (`Optional[Callable[[], None]]`):
        An optional function (like a console event processor) that is called
        before the conversion takes place.
        Executing this method often clears the `events` queue.

    For a detailed definition of `RunResponseProtocol` and its components,
    refer to: https://github.com/ag2ai/ag2/blob/main/autogen/io/run_response.py

    Args:
        obj (AutogenRunResponse): The AutogenRunResponse object to convert.
            This object must conform to the `RunResponseProtocol`.

    Returns:
        JsonDict: A dictionary representation of the AutogenRunResponse, ready
            to be serialized into JSON. The dictionary includes keys like
            'summary', 'messages', 'context_variables', 'last_speaker',
            and 'cost'.
    """
    if hasattr(obj, "process"):
        obj.process()

    last_speaker = None
    if getattr(obj, "last_speaker", None) is not None:
        last_speaker = {
            "name": getattr(obj.last_speaker, "name", None),
            "description": getattr(obj.last_speaker, "description", None),
        }

    cost = None
    if getattr(obj, "cost", None) is not None:
        if hasattr(obj.cost, "model_dump_json"):
            cost = json.loads(obj.cost.model_dump_json())
        else:
            cost = str(obj.cost)

    result = {
        "summary": getattr(obj, "summary", None),
        "messages": list(getattr(obj, "messages", [])),
        "context_variables": getattr(obj, "context_variables", None),
        "last_speaker": last_speaker,
        "cost": cost,
    }
    return json.loads(json.dumps(result))


def to_json_serializable_autogen_object(
    obj: Union[
        AutogenChatResult,
        AutogenRunResponse,
    ]
) -> JsonDict:
    """Converts an Autogen object to a JSON serializable object.

    In `ag2<=0.8.4`, `.run()` will return a `ChatResult` object.
    In `ag2>=0.8.5`, `.run()` will return a `RunResponse` object.

    Args:
        obj (Union[AutogenChatResult, AutogenRunResponse]):
            Required. The Autogen object to be converted to a JSON
            serializable object.

    Returns:
        JsonDict: A JSON serializable object.
    """
    if isinstance(obj, AutogenChatResult):
        return _dataclass_to_dict_or_raise(obj)
    return _autogen_run_response_protocol_to_dict(obj)


def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
    """Converts the contents of the httpbody message to JSON format.

    Args:
        body (httpbody_pb2.HttpBody):
            Required. The HttpBody message to be converted to JSON.

    Yields:
        Any: A JSON object or the original body if it is not JSON or None.
    """
    content_type = getattr(body, "content_type", None)
    data = getattr(body, "data", None)

    if content_type is None or data is None or "application/json" not in content_type:
        yield body
        return

    try:
        utf8_data = data.decode("utf-8")
    except Exception as e:
        LOGGER.warning(f"Failed to decode data: {data}. Exception: {e}")
        yield body
        return

    if not utf8_data:
        yield None
        return

    # Handle the case of multiple dictionaries delimited by newlines.
    for line in utf8_data.split("\n"):
        if line:
            try:
                line = json.loads(line)
            except Exception as e:
                LOGGER.warning(f"Failed to parse JSON: {line}. Exception: {e}")
            yield line


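# --- Illustrative usage sketch (not part of the original file) ---
# Parsing a newline-delimited JSON HTTP body with the helper above.

from google.api import httpbody_pb2

body = httpbody_pb2.HttpBody(
    content_type="application/json",
    data=b'{"step": 1}\n{"step": 2}',
)
for chunk in yield_parsed_json(body):
    print(chunk)  # {'step': 1} then {'step': 2}

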
def parse_constraints(
    constraints: Sequence[str],
) -> Mapping[str, "SpecifierSet"]:
    """Parses a list of constraints into a dict of requirements.

    Args:
        constraints (list[str]):
            Required. The list of package requirements to parse. This is
            assumed to come from the `requirements.txt` file.

    Returns:
        dict[str, SpecifierSet]: The specifiers for each package.
    """
    requirements = _import_packaging_requirements_or_raise()
    result = {}
    for constraint in constraints:
        try:
            requirement = requirements.Requirement(constraint)
        except Exception as e:
            LOGGER.warning(f"Failed to parse constraint: {constraint}. Exception: {e}")
            continue
        result[requirement.name] = requirement.specifier or None
    return result


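# --- Illustrative usage sketch (not part of the original file) ---
# Parsing requirements.txt-style constraints into a name -> SpecifierSet
# mapping. An unpinned package maps to None because its specifier set is
# empty.

specifiers = parse_constraints(["cloudpickle==3.0.0", "pydantic>=2.6", "requests"])
# {'cloudpickle': <SpecifierSet('==3.0.0')>,
#  'pydantic': <SpecifierSet('>=2.6')>,
#  'requests': None}

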
def validate_requirements_or_warn(
    obj: Any,
    requirements: List[str],
) -> List[str]:
    """Validates the scanned requirements of an object against constraints.

    It scans `obj` for the packages it uses, compares them against the
    user-provided `requirements`, logs a warning for any missing or
    incompatible packages, and returns a copy of `requirements` with any
    absent required packages appended.
    """
    requirements = requirements.copy()
    try:
        current_requirements = scan_requirements(obj)
        LOGGER.info(f"Identified the following requirements: {current_requirements}")
        constraints = parse_constraints(requirements)
        missing_requirements = compare_requirements(current_requirements, constraints)
        for warning_type, warnings in missing_requirements.get(
            _WARNINGS_KEY, {}
        ).items():
            if warnings:
                LOGGER.warning(
                    f"The following requirements are {warning_type}: {warnings}"
                )
        for action_type, actions in missing_requirements.get(_ACTIONS_KEY, {}).items():
            if actions and action_type == _ACTION_APPEND:
                for action in actions:
                    requirements.append(action)
                LOGGER.info(f"The following requirements are appended: {actions}")
    except Exception as e:
        LOGGER.warning(f"Failed to compile requirements: {e}")
    return requirements


def compare_requirements(
    requirements: Mapping[str, str],
    constraints: Union[Sequence[str], Mapping[str, "SpecifierSet"]],
    *,
    required_packages: Optional[Sequence[str]] = None,
) -> Mapping[str, Mapping[str, Any]]:
    """Compares the requirements with the constraints.

    Args:
        requirements (Mapping[str, str]):
            Required. The packages (and their versions) to compare with the
            constraints. This is assumed to be the result of
            `scan_requirements`.
        constraints (Union[Sequence[str], Mapping[str, SpecifierSet]]):
            Required. The package constraints to compare against. This is
            assumed to be the result of `parse_constraints`.
        required_packages (Sequence[str]):
            Optional. The set of packages that are required to be in the
            constraints. It defaults to the set of packages that are required
            for deployment on Agent Engine.

    Returns:
        dict[str, dict[str, Any]]: The comparison result as a dictionary
        containing:
            * warnings:
                * missing: The set of packages that are not in the constraints.
                * incompatible: The set of packages that are in the constraints
                  but have versions that are not in the constraint specifier.
            * actions:
                * append: The set of packages that are not in the constraints
                  but should be appended to the constraints.
    """
    packaging_version = _import_packaging_version_or_raise()
    if required_packages is None:
        required_packages = _DEFAULT_REQUIRED_PACKAGES
    result = _RequirementsValidationResult(
        warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()),
        actions=_RequirementsValidationActions(append=set()),
    )
    if isinstance(constraints, list):
        constraints = parse_constraints(constraints)
    for package, package_version in requirements.items():
        if package not in constraints:
            result[_WARNINGS_KEY][_WARNING_MISSING].add(package)
            if package in required_packages:
                result[_ACTIONS_KEY][_ACTION_APPEND].add(
                    f"{package}=={package_version}"
                )
            continue
        if package_version:
            package_specifier = constraints[package]
            if not package_specifier:
                continue
            if packaging_version.Version(package_version) not in package_specifier:
                result[_WARNINGS_KEY][_WARNING_INCOMPATIBLE].add(
                    f"{package}=={package_version} (required: {str(package_specifier)})"
                )
    return result


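# --- Illustrative usage sketch (not part of the original file) ---
# Comparing scanned requirements against user-declared constraints.
# `cloudpickle` is absent from the constraints and is a required package,
# so it lands in both "missing" and the "append" action; `requests` is
# present but pinned to a different version, so it is "incompatible".

scanned = {"cloudpickle": "3.0.0", "requests": "2.31.0"}
report = compare_requirements(scanned, ["requests==2.30.0"])
# report["warnings"]["missing"]      -> {"cloudpickle"}
# report["warnings"]["incompatible"] -> {"requests==2.31.0 (required: ==2.30.0)"}
# report["actions"]["append"]        -> {"cloudpickle==3.0.0"}

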
def scan_requirements(
    obj: Any,
    ignore_modules: Optional[Sequence[str]] = None,
    package_distributions: Optional[Mapping[str, Sequence[str]]] = None,
    inspect_getmembers_kwargs: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, str]:
    """Scans the object for modules and returns the requirements discovered.

    This is not a comprehensive scan of the object, and only detects common
    cases based on the members of the object returned by `dir(obj)`.

    Args:
        obj (Any):
            Required. The object to scan for package requirements.
        ignore_modules (Sequence[str]):
            Optional. The set of modules to ignore. It defaults to the set of
            built-in and stdlib modules.
        package_distributions (Mapping[str, Sequence[str]]):
            Optional. The mapping of module names to the set of packages that
            contain them. It defaults to the set of packages from
            `importlib_metadata.packages_distributions()`.
        inspect_getmembers_kwargs (Mapping[str, Any]):
            Optional. The keyword arguments to pass to `inspect.getmembers`. It
            defaults to an empty dictionary.

    Returns:
        dict[str, str]: The packages that were discovered, mapped to their
            installed versions.
    """
    if ignore_modules is None:
        ignore_modules = _BASE_MODULES
    if package_distributions is None:
        package_distributions = _PACKAGE_DISTRIBUTIONS
    modules_found = set(_DEFAULT_REQUIRED_PACKAGES)
    inspect_getmembers_kwargs = inspect_getmembers_kwargs or {}
    for _, attr in inspect.getmembers(obj, **inspect_getmembers_kwargs):
        if not attr or inspect.isbuiltin(attr) or not hasattr(attr, "__module__"):
            continue
        module_name = (attr.__module__ or "").split(".")[0]
        if module_name and module_name not in ignore_modules:
            for module in package_distributions.get(module_name, []):
                modules_found.add(module)
    return {module: importlib_metadata.version(module) for module in modules_found}


def generate_schema(
    f: Callable[..., Any],
    *,
    schema_name: Optional[str] = None,
    descriptions: Mapping[str, str] = {},
    required: Sequence[str] = [],
) -> JsonDict:
    """Generates the OpenAPI Schema for a callable object.

    Only positional and keyword arguments of the function `f` will be
    supported in the OpenAPI Schema that is generated. I.e. `*args` and
    `**kwargs` will not be present in the OpenAPI schema returned from this
    function. For those cases, you can either include it in the docstring for
    `f`, or modify the OpenAPI schema returned from this function to include
    additional arguments.

    Args:
        f (Callable):
            Required. The function to generate an OpenAPI Schema for.
        schema_name (str):
            Optional. The name for the OpenAPI schema. If unspecified, the
            name of the Callable will be used.
        descriptions (Mapping[str, str]):
            Optional. A `{name: description}` mapping for annotating input
            arguments of the function with user-provided descriptions. It
            defaults to an empty dictionary (i.e. there will not be any
            description for any of the inputs).
        required (Sequence[str]):
            Optional. For the user to specify the set of required arguments in
            function calls to `f`. If not specified, it will be automatically
            inferred from `f`.

    Returns:
        dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format.
    """
    pydantic = _import_pydantic_or_raise()
    defaults = dict(inspect.signature(f).parameters)
    fields_dict = {
        name: (
            # 1. We infer the argument type here: use Any rather than None so
            # it will not try to auto-infer the type based on the default value.
            (param.annotation if param.annotation != inspect.Parameter.empty else Any),
            pydantic.Field(
                # 2. We do not support default values for now.
                # default=(
                #     param.default if param.default != inspect.Parameter.empty
                #     else None
                # ),
                # 3. We support user-provided descriptions.
                description=descriptions.get(name, None),
            ),
        )
        for name, param in defaults.items()
        # We do not support *args or **kwargs
        if param.kind
        in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
            inspect.Parameter.POSITIONAL_ONLY,
        )
    }
    parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
    # Postprocessing
    # 4. Suppress unnecessary title generation:
    #    * https://github.com/pydantic/pydantic/issues/1051
    #    * http://cl/586221780
    parameters.pop("title", "")
    for name, function_arg in parameters.get("properties", {}).items():
        function_arg.pop("title", "")
        annotation = defaults[name].annotation
        # 5. Nullable fields:
        #    * https://github.com/pydantic/pydantic/issues/1270
        #    * https://stackoverflow.com/a/58841311
        #    * https://github.com/pydantic/pydantic/discussions/4872
        if typing.get_origin(annotation) is Union and type(None) in typing.get_args(
            annotation
        ):
            # for "typing.Optional" arguments, function_arg might be a
            # dictionary like
            #
            #   {'anyOf': [{'type': 'integer'}, {'type': 'null'}]}
            for schema in function_arg.pop("anyOf", []):
                schema_type = schema.get("type")
                if schema_type and schema_type != "null":
                    function_arg["type"] = schema_type
                    break
            function_arg["nullable"] = True
    # 6. Annotate required fields.
    if required:
        # We use the user-provided "required" fields if specified.
        parameters["required"] = required
    else:
        # Otherwise we infer it from the function signature.
        parameters["required"] = [
            k
            for k in defaults
            if (
                defaults[k].default == inspect.Parameter.empty
                and defaults[k].kind
                in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                    inspect.Parameter.POSITIONAL_ONLY,
                )
            )
        ]
    schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
    if schema_name:
        schema["name"] = schema_name
    return schema


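# --- Illustrative usage sketch (not part of the original file) ---
# Generating an OpenAPI-style schema for a simple function. `place` gets a
# user-provided description; `unit` is Optional with a default, so it is
# marked nullable and omitted from the inferred required list. The function
# below is a hypothetical stand-in.

from typing import Optional


def get_weather(place: str, unit: Optional[str] = None) -> str:
    """Returns the weather in `place`."""
    return f"Sunny in {place}"


schema = generate_schema(
    get_weather,
    descriptions={"place": "The city to look up."},
)
# schema["name"]                             -> "get_weather"
# schema["parameters"]["required"]           -> ["place"]
# schema["parameters"]["properties"]["unit"] -> {..., "nullable": True}

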
def is_noop_or_proxy_tracer_provider(tracer_provider) -> bool:
    """Returns True if the tracer_provider is Proxy or NoOp."""
    opentelemetry = _import_opentelemetry_or_warn()
    ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider
    NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider
    return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider))


def _import_cloud_storage_or_raise() -> types.ModuleType:
    """Tries to import the Cloud Storage module."""
    try:
        from google.cloud import storage
    except ImportError as e:
        raise ImportError(
            "Cloud Storage is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        ) from e
    return storage


def _import_cloudpickle_or_raise() -> types.ModuleType:
    """Tries to import the cloudpickle module."""
    try:
        import cloudpickle  # noqa:F401
    except ImportError as e:
        raise ImportError(
            "cloudpickle is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        ) from e
    return cloudpickle


def _import_pydantic_or_raise() -> types.ModuleType:
    """Tries to import the pydantic module."""
    try:
        import pydantic

        _ = pydantic.Field
    except AttributeError:
        from pydantic import v1 as pydantic
    except ImportError as e:
        raise ImportError(
            "pydantic is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        ) from e
    return pydantic


def _import_packaging_requirements_or_raise() -> types.ModuleType:
    """Tries to import the packaging.requirements module."""
    try:
        from packaging import requirements
    except ImportError as e:
        raise ImportError(
            "packaging.requirements is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        ) from e
    return requirements


def _import_packaging_version_or_raise() -> types.ModuleType:
    """Tries to import the packaging.version module."""
    try:
        from packaging import version
    except ImportError as e:
        raise ImportError(
            "packaging.version is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        ) from e
    return version


def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the opentelemetry module."""
    try:
        import opentelemetry  # noqa:F401

        return opentelemetry
    except ImportError:
        LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
    return None


def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the opentelemetry.sdk.trace module."""
    try:
        import opentelemetry.sdk.trace  # noqa:F401

        return opentelemetry.sdk.trace
    except ImportError:
        LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
    return None


def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the google.cloud.trace_v2 module."""
    try:
        import google.cloud.trace_v2

        return google.cloud.trace_v2
    except ImportError:
        LOGGER.warning(
            "google-cloud-trace is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
    return None


def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the opentelemetry.exporter.cloud_trace module."""
    try:
        import opentelemetry.exporter.cloud_trace  # noqa:F401

        return opentelemetry.exporter.cloud_trace
    except ImportError:
        LOGGER.warning(
            "opentelemetry-exporter-gcp-trace is not installed. Please "
            "call 'pip install google-cloud-aiplatform[agent_engines]'."
        )
    return None


def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the openinference.instrumentation.langchain module."""
    try:
        import openinference.instrumentation.langchain  # noqa:F401

        return openinference.instrumentation.langchain
    except ImportError:
        LOGGER.warning(
            "openinference-instrumentation-langchain is not installed. Please "
            "call 'pip install google-cloud-aiplatform[langchain]'."
        )
    return None


def _import_openinference_autogen_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the openinference.instrumentation.autogen module."""
    try:
        import openinference.instrumentation.autogen  # noqa:F401

        return openinference.instrumentation.autogen
    except ImportError:
        LOGGER.warning(
            "openinference-instrumentation-autogen is not installed. Please "
            "call 'pip install google-cloud-aiplatform[ag2]'."
        )
    return None


def _import_autogen_tools_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the autogen.tools module."""
    try:
        from autogen import tools

        return tools
    except ImportError:
        LOGGER.warning(
            "autogen.tools is not installed. Please "
            "call `pip install google-cloud-aiplatform[ag2]`."
        )
    return None
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,487 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Union,
)

if TYPE_CHECKING:
    try:
        from autogen import agentchat

        ConversableAgent = agentchat.ConversableAgent
        ChatResult = agentchat.ChatResult
    except ImportError:
        ConversableAgent = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any


def _prepare_runnable_kwargs(
    runnable_kwargs: Mapping[str, Any],
    system_instruction: str,
    runnable_name: str,
    llm_config: Mapping[str, Any],
) -> Mapping[str, Any]:
    """Prepares the configuration for a runnable, applying defaults and enforcing constraints."""
    if runnable_kwargs is None:
        runnable_kwargs = {}

    if (
        "human_input_mode" in runnable_kwargs
        and runnable_kwargs["human_input_mode"] != "NEVER"
    ):
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            f"human_input_mode={runnable_kwargs['human_input_mode']} "
            "is not supported. It will be enforced to 'NEVER'."
        )
        runnable_kwargs["human_input_mode"] = "NEVER"

    if "system_message" not in runnable_kwargs and system_instruction:
        runnable_kwargs["system_message"] = system_instruction

    if "name" not in runnable_kwargs:
        runnable_kwargs["name"] = runnable_name

    if "llm_config" not in runnable_kwargs:
        runnable_kwargs["llm_config"] = llm_config

    return runnable_kwargs


def _default_runnable_builder(
    **runnable_kwargs: Any,
) -> "ConversableAgent":
    """Builds a default AG2 ConversableAgent from the given keyword arguments."""
    from autogen import agentchat

    return agentchat.ConversableAgent(**runnable_kwargs)


def _default_instrumentor_builder(project_id: str):
    """Builds the default instrumentor for exporting traces to Cloud Trace."""
    from vertexai.agent_engines import _utils

    cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
    cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
    openinference_autogen = _utils._import_openinference_autogen_or_warn()
    opentelemetry = _utils._import_opentelemetry_or_warn()
    opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
    if all(
        (
            cloud_trace_exporter,
            cloud_trace_v2,
            openinference_autogen,
            opentelemetry,
            opentelemetry_sdk_trace,
        )
    ):
        import google.auth

        credentials, _ = google.auth.default()
        span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
            project_id=project_id,
            client=cloud_trace_v2.TraceServiceClient(
                credentials=credentials.with_quota_project(project_id),
            ),
        )
        span_processor: SpanProcessor = (
            opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                span_exporter=span_exporter,
            )
        )
        tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
        # Get the appropriate tracer provider:
        # 1. If _TRACER_PROVIDER is already set, use that.
        # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
        #    variable is set, use that.
        # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
        # If none of the above is set, we log a warning and create a
        # tracer provider.
        if not tracer_provider:
            from google.cloud.aiplatform import base

            _LOGGER = base.Logger(__name__)
            _LOGGER.warning(
                "No tracer provider. By default, "
                "we should get one of the following providers: "
                "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                "or _PROXY_TRACER_PROVIDER."
            )
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids AttributeError:
        # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
        # attribute 'add_span_processor'.
        if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids the "OpenTelemetry client already exists" error.
        _override_active_span_processor(
            tracer_provider,
            opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
        )
        tracer_provider.add_span_processor(span_processor)
        # Keep the instrumentation up-to-date when creating multiple
        # AG2Agents. We deliberately override the instrument each time, so
        # that if different agents end up using different instrumentations,
        # we guarantee that the user is always working with the most recent
        # agent's instrumentation.
        instrumentor = openinference_autogen.AutogenInstrumentor()
        instrumentor.uninstrument()
        instrumentor.instrument()
        return instrumentor
    else:
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            "enable_tracing=True but proceeding with tracing disabled "
            "because not all packages for tracing have been installed."
        )
        return None


def _validate_callable_parameters_are_annotated(callable: Callable):
    """Validates that the parameters of the callable have type annotations.

    This ensures that they can be used for constructing AG2 tools that are
    usable with Gemini function calling.
    """
    import inspect

    parameters = dict(inspect.signature(callable).parameters)
    for name, parameter in parameters.items():
        if parameter.annotation == inspect.Parameter.empty:
            raise TypeError(
                f"Callable={callable.__name__} has untyped input_arg={name}. "
                f"Please specify a type when defining it, e.g. `{name}: str`."
            )


def _validate_tools(tools: Sequence[Callable[..., Any]]):
    """Validates that the tools are usable for tool calling."""
    for tool in tools:
        if isinstance(tool, Callable):
            _validate_callable_parameters_are_annotated(tool)


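# --- Illustrative usage sketch (not part of the original file) ---
# Tools must have typed parameters so they can be converted into AG2 tools
# usable with Gemini function calling. The function below is a hypothetical
# stub.

def get_exchange_rate(currency_from: str, currency_to: str) -> dict:
    """Returns a stub exchange-rate payload (illustrative only)."""
    return {"from": currency_from, "to": currency_to, "rate": 1.0}


_validate_tools([get_exchange_rate])  # passes: all parameters annotated


def bad_tool(currency_from):  # no type annotation
    return currency_from


# _validate_tools([bad_tool]) would raise TypeError for `currency_from`.

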
def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple AG2Agents in the same environment, it's
    crucial to manage trace exports carefully. Each agent needs its own span
    processor tied to a unique project ID. While we add a new span processor
    for each agent, this can lead to unexpected behavior. For instance, with
    two agents linked to different projects, traces from the second agent
    might be sent to both projects. To prevent this and guarantee traces go
    to the correct project, we overwrite the active span processor whenever
    a new AG2Agent is created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The active span processor that overrides the tracer provider's
            active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor


class AG2Agent:
    """An AG2 Agent."""

    def __init__(
        self,
        model: str,
        runnable_name: str,
        *,
        api_type: Optional[str] = None,
        llm_config: Optional[Mapping[str, Any]] = None,
        system_instruction: Optional[str] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None,
        tools: Optional[Sequence[Callable[..., Any]]] = None,
        enable_tracing: bool = False,
        instrumentor_builder: Optional[Callable[..., Any]] = None,
    ):
        """Initializes the AG2 Agent.

        Under the hood, assuming .set_up() is called, this will correspond to

        ```python
        # runnable_builder
        runnable = runnable_builder(
            llm_config=llm_config,
            system_message=system_instruction,
            **runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to

        ```python
        # llm_config
        llm_config = {
            "config_list": [{
                "project_id": initializer.global_config.project,
                "location": initializer.global_config.location,
                "model": "gemini-1.0-pro-001",
                "api_type": "google",
            }]
        }

        # runnable_builder
        runnable = ConversableAgent(
            llm_config=llm_config,
            name="Default AG2 Agent",
            system_message="You are a helpful AI Assistant.",
            human_input_mode="NEVER",
        )
        ```

        By default, if `llm_config` is not specified, a default configuration
        will be created using the provided `model` and `api_type`.

        If `runnable_builder` is not specified, a default runnable builder will
        be used, configured with the `system_instruction`, `runnable_name` and
        `runnable_kwargs`.

        Args:
            model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
                Used to create a default `llm_config` if one is not provided.
                This parameter is ignored if `llm_config` is provided.
            runnable_name (str):
                Required. The name of the runnable.
                This name is used as the default `runnable_kwargs["name"]`
                unless `runnable_kwargs` already contains a "name", in which
                case the provided `runnable_kwargs["name"]` will be used.
            api_type (str):
                Optional. The API type to use for the language model.
                Used to create a default `llm_config` if one is not provided.
                This parameter is ignored if `llm_config` is provided.
            llm_config (Mapping[str, Any]):
                Optional. Configuration dictionary for the language model.
                If provided, this configuration will be used directly.
                Otherwise, a default `llm_config` will be created using `model`
                and `api_type`. This `llm_config` is used as the default
                `runnable_kwargs["llm_config"]` unless `runnable_kwargs` already
                contains an "llm_config", in which case the provided
                `runnable_kwargs["llm_config"]` will be used.
            system_instruction (str):
                Optional. The system instruction for the agent.
                This instruction is used as the default
                `runnable_kwargs["system_message"]` unless `runnable_kwargs`
                already contains a "system_message", in which case the provided
                `runnable_kwargs["system_message"]` will be used.
            runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                the runnable. Details of the kwargs can be found in
                https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent.
                `runnable_kwargs` only supports `human_input_mode="NEVER"`.
                Other `human_input_mode` values will trigger a warning.
            runnable_builder (Callable[..., "ConversableAgent"]):
                Optional. Callable that returns a new runnable. This can be used
                for customizing the orchestration logic of the Agent.
                If not provided, a default runnable builder will be used.
            tools (Sequence[Callable[..., Any]]):
                Optional. The tools for the agent to be able to use. All input
                callables (e.g. function or class method) will be converted
                to an AG2 tool. Defaults to None.
            enable_tracing (bool):
                Optional. Whether to enable tracing in Cloud Trace. Defaults to
                False.
            instrumentor_builder (Callable[..., Any]):
                Optional. Callable that returns a new instrumentor. This can be
                used for customizing the instrumentation logic of the Agent.
                If not provided, a default instrumentor builder will be used.
                This parameter is ignored if `enable_tracing` is False.
        """
        from google.cloud.aiplatform import initializer

        self._tmpl_attrs: dict[str, Any] = {
            "project": initializer.global_config.project,
            "location": initializer.global_config.location,
            "model_name": model,
            "api_type": api_type or "google",
            "system_instruction": system_instruction,
            "runnable_name": runnable_name,
            "tools": [],
            "ag2_tool_objects": [],
            "runnable": None,
            "runnable_builder": runnable_builder,
            "instrumentor": None,
            "instrumentor_builder": instrumentor_builder,
            "enable_tracing": enable_tracing,
        }
        self._tmpl_attrs["llm_config"] = llm_config or {
            "config_list": [
                {
                    "project_id": self._tmpl_attrs.get("project"),
                    "location": self._tmpl_attrs.get("location"),
                    "model": self._tmpl_attrs.get("model_name"),
                    "api_type": self._tmpl_attrs.get("api_type"),
                }
            ]
        }
        self._tmpl_attrs["runnable_kwargs"] = _prepare_runnable_kwargs(
            runnable_kwargs=runnable_kwargs,
            llm_config=self._tmpl_attrs.get("llm_config"),
            system_instruction=self._tmpl_attrs.get("system_instruction"),
            runnable_name=self._tmpl_attrs.get("runnable_name"),
        )
        if tools:
            # We validate tools at initialization for actionable feedback
            # before they are deployed.
            _validate_tools(tools)
            self._tmpl_attrs["tools"] = tools

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It initializes the runnable and binds it to the tools.

        This method should not be called for an object that is being passed
        to the ReasoningEngine service for deployment, as it initializes
        clients that cannot be serialized.
        """
        if self._tmpl_attrs.get("enable_tracing"):
            instrumentor_builder = (
                self._tmpl_attrs.get("instrumentor_builder")
                or _default_instrumentor_builder
            )
            self._tmpl_attrs["instrumentor"] = instrumentor_builder(
                project_id=self._tmpl_attrs.get("project")
            )

        # Set up tools.
        tools = self._tmpl_attrs.get("tools")
        ag2_tool_objects = self._tmpl_attrs.get("ag2_tool_objects")
        if tools and not ag2_tool_objects:
            from vertexai.agent_engines import _utils

            autogen_tools = _utils._import_autogen_tools_or_warn()
            if autogen_tools:
                for tool in tools:
                    ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))

        # Set up runnable.
        runnable_builder = (
            self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
        )
        self._tmpl_attrs["runnable"] = runnable_builder(
            **self._tmpl_attrs.get("runnable_kwargs")
        )

def clone(self) -> "AG2Agent":
|
||||
"""Returns a clone of the AG2Agent."""
|
||||
import copy
|
||||
|
||||
return AG2Agent(
|
||||
model=self._tmpl_attrs.get("model_name"),
|
||||
api_type=self._tmpl_attrs.get("api_type"),
|
||||
llm_config=copy.deepcopy(self._tmpl_attrs.get("llm_config")),
|
||||
system_instruction=self._tmpl_attrs.get("system_instruction"),
|
||||
runnable_name=self._tmpl_attrs.get("runnable_name"),
|
||||
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
|
||||
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
|
||||
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
|
||||
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
|
||||
instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, Mapping[str, Any]],
|
||||
max_turns: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Queries the Agent with the given input.
|
||||
|
||||
Args:
|
||||
input (Union[str, Mapping[str, Any]]):
|
||||
Required. The input to be passed to the Agent.
|
||||
max_turns (int):
|
||||
Optional. The maximum number of turns to run the agent for.
|
||||
If not provided, the agent will run indefinitely.
|
||||
If `max_turns` is a `float`, it will be converted to `int`
|
||||
through rounding.
|
||||
**kwargs:
|
||||
Optional. Any additional keyword arguments to be passed to the
|
||||
`.run()` method of the corresponding runnable.
|
||||
Details of the kwargs can be found in
|
||||
https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent#run.
|
||||
The `user_input` parameter defaults to `False`, and should not
|
||||
be passed through `kwargs`.
|
||||
|
||||
Returns:
|
||||
The output of querying the Agent with the given input.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = {"content": input}
|
||||
|
||||
if max_turns and isinstance(max_turns, float):
|
||||
# Supporting auto-conversion float to int.
|
||||
max_turns = round(max_turns)
|
||||
|
||||
if "user_input" in kwargs:
|
||||
from google.cloud.aiplatform import base
|
||||
|
||||
_LOGGER = base.Logger(__name__)
|
||||
_LOGGER.warning(
|
||||
"The `user_input` parameter should not be passed through"
|
||||
"kwargs. The `user_input` defaults to `False`."
|
||||
)
|
||||
kwargs.pop("user_input")
|
||||
|
||||
if not self._tmpl_attrs.get("runnable"):
|
||||
self.set_up()
|
||||
|
||||
response = self._tmpl_attrs.get("runnable").run(
|
||||
message=input,
|
||||
user_input=False,
|
||||
tools=self._tmpl_attrs.get("ag2_tool_objects"),
|
||||
max_turns=max_turns,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
from vertexai.agent_engines import _utils
|
||||
|
||||
return _utils.to_json_serializable_autogen_object(response)
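

# A minimal local usage sketch of `AG2Agent.query` (illustrative, not part of
# the library API). It assumes `vertexai.init(...)` has been called and the
# AG2/autogen dependencies are installed; the tool `get_exchange_rate` and
# its endpoint are hypothetical examples.
def _example_ag2_usage() -> Dict[str, Any]:
    def get_exchange_rate(currency_from: str, currency_to: str) -> dict:
        """Returns the exchange rate between two currencies (hypothetical API)."""
        import requests

        return requests.get(
            "https://api.frankfurter.app/latest",
            params={"from": currency_from, "to": currency_to},
        ).json()

    agent = AG2Agent(
        model="gemini-1.5-pro",
        runnable_name="Exchange Rate Agent",
        tools=[get_exchange_rate],
    )
    # query() calls set_up() on first use; a float max_turns would be rounded.
    return agent.query(
        input="What is the exchange rate from USD to EUR?",
        max_turns=2,
    )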

@@ -0,0 +1,673 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    Union,
)

if TYPE_CHECKING:
    try:
        from langchain_core import runnables
        from langchain_core import tools as lc_tools
        from langchain_core.language_models import base as lc_language_models

        BaseTool = lc_tools.BaseTool
        BaseLanguageModel = lc_language_models.BaseLanguageModel
        GetSessionHistoryCallable = runnables.history.GetSessionHistoryCallable
        RunnableConfig = runnables.RunnableConfig
        RunnableSerializable = runnables.RunnableSerializable
    except ImportError:
        BaseTool = Any
        BaseLanguageModel = Any
        GetSessionHistoryCallable = Any
        RunnableConfig = Any
        RunnableSerializable = Any

    try:
        from langchain_google_vertexai.functions_utils import _ToolsType

        _ToolLike = _ToolsType
    except ImportError:
        _ToolLike = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any


def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
    # https://github.com/langchain-ai/langchain/blob/5784dfed001730530637793bea1795d9d5a7c244/libs/core/langchain_core/runnables/history.py#L237-L241
    runnable_kwargs = {
        # input_messages_key (str): Must be specified if the underlying
        # agent accepts a dict as input.
        "input_messages_key": "input",
        # output_messages_key (str): Must be specified if the underlying
        # agent returns a dict as output.
        "output_messages_key": "output",
    }
    if has_history:
        # history_messages_key (str): Must be specified if the underlying
        # agent accepts a dict as input and a separate key for historical
        # messages.
        runnable_kwargs["history_messages_key"] = "history"
    return runnable_kwargs


def _default_output_parser():
    try:
        from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
    except (ModuleNotFoundError, ImportError):
        # Fall back to an older version if needed.
        from langchain.agents.output_parsers.openai_tools import (
            OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
        )

    return ToolsAgentOutputParser()


def _default_model_builder(
    model_name: str,
    *,
    project: str,
    location: str,
    model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
    import vertexai
    from google.cloud.aiplatform import initializer
    from langchain_google_vertexai import ChatVertexAI

    model_kwargs = model_kwargs or {}
    current_project = initializer.global_config.project
    current_location = initializer.global_config.location
    vertexai.init(project=project, location=location)
    model = ChatVertexAI(model_name=model_name, **model_kwargs)
    vertexai.init(project=current_project, location=current_location)
    return model
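

# Illustrative sketch (not part of the library): `model_builder` is the
# extension point that `_default_model_builder` above fills in by default.
# A custom builder only needs to honor the same keyword arguments; the
# pinned temperature below is an assumption chosen for the example.
def _example_model_builder(
    model_name: str,
    *,
    project: str,
    location: str,
    model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
    from langchain_google_vertexai import ChatVertexAI

    # Force a deterministic temperature unless the caller overrides it.
    model_kwargs = {"temperature": 0.0, **(model_kwargs or {})}
    return ChatVertexAI(
        model_name=model_name,
        project=project,
        location=location,
        **model_kwargs,
    )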


def _default_runnable_builder(
    model: "BaseLanguageModel",
    *,
    system_instruction: Optional[str] = None,
    tools: Optional[Sequence["_ToolLike"]] = None,
    prompt: Optional["RunnableSerializable"] = None,
    output_parser: Optional["RunnableSerializable"] = None,
    chat_history: Optional["GetSessionHistoryCallable"] = None,
    model_tool_kwargs: Optional[Mapping[str, Any]] = None,
    agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
    runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
    from langchain_core import tools as lc_tools
    from langchain.agents import AgentExecutor
    from langchain.tools.base import StructuredTool

    # The prompt template and runnable_kwargs need to be customized depending
    # on whether the user intends for the agent to have history. The way the
    # user would reflect that is by setting chat_history (which defaults to
    # None).
    has_history: bool = chat_history is not None
    prompt = prompt or _default_prompt(
        has_history=has_history,
        system_instruction=system_instruction,
    )
    output_parser = output_parser or _default_output_parser()
    model_tool_kwargs = model_tool_kwargs or {}
    agent_executor_kwargs = agent_executor_kwargs or {}
    runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history)
    if tools:
        model = model.bind_tools(tools=tools, **model_tool_kwargs)
    else:
        tools = []
    agent_executor = AgentExecutor(
        agent=prompt | model | output_parser,
        tools=[
            tool
            if isinstance(tool, lc_tools.BaseTool)
            else StructuredTool.from_function(tool)
            for tool in tools
            if isinstance(tool, (Callable, lc_tools.BaseTool))
        ],
        **agent_executor_kwargs,
    )
    if has_history:
        from langchain_core.runnables.history import RunnableWithMessageHistory

        return RunnableWithMessageHistory(
            runnable=agent_executor,
            get_session_history=chat_history,
            **runnable_kwargs,
        )
    return agent_executor
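

# Illustrative sketch (not part of the library): a custom `runnable_builder`
# can replace the AgentExecutor-based default above with any LangChain
# runnable, as long as it accepts the same keyword arguments. Ignoring the
# unused kwargs here is an assumption made for brevity.
def _example_runnable_builder(
    model: "BaseLanguageModel",
    **kwargs: Any,
) -> "RunnableSerializable":
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate

    # A plain prompt -> model -> string pipeline, with no tool calling.
    prompt = ChatPromptTemplate.from_messages([("user", "{input}")])
    return prompt | model | StrOutputParser()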


def _default_instrumentor_builder(project_id: str):
    from vertexai.agent_engines import _utils

    cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
    cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
    openinference_langchain = _utils._import_openinference_langchain_or_warn()
    opentelemetry = _utils._import_opentelemetry_or_warn()
    opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
    if all(
        (
            cloud_trace_exporter,
            cloud_trace_v2,
            openinference_langchain,
            opentelemetry,
            opentelemetry_sdk_trace,
        )
    ):
        import google.auth

        credentials, _ = google.auth.default()
        span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
            project_id=project_id,
            client=cloud_trace_v2.TraceServiceClient(
                credentials=credentials.with_quota_project(project_id),
            ),
        )
        span_processor: SpanProcessor = (
            opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                span_exporter=span_exporter,
            )
        )
        tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
        # Get the appropriate tracer provider:
        # 1. If _TRACER_PROVIDER is already set, use that.
        # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
        #    variable is set, use that.
        # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
        # If none of the above is set, we log a warning and create a new
        # tracer provider.
        if not tracer_provider:
            from google.cloud.aiplatform import base

            _LOGGER = base.Logger(__name__)
            _LOGGER.warning(
                "No tracer provider. By default, "
                "we should get one of the following providers: "
                "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                "or _PROXY_TRACER_PROVIDER."
            )
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids AttributeError:
        # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
        # attribute 'add_span_processor'.
        if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids the OpenTelemetry "client already exists" error.
        _override_active_span_processor(
            tracer_provider,
            opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
        )
        tracer_provider.add_span_processor(span_processor)
        # Keep the instrumentation up-to-date when creating multiple
        # LangchainAgents: we deliberately override the instrumentation each
        # time, so that if different agents end up using different
        # instrumentations, the user is always working with the most recent
        # agent's instrumentation.
        instrumentor = openinference_langchain.LangChainInstrumentor()
        if instrumentor.is_instrumented_by_opentelemetry:
            instrumentor.uninstrument()
        instrumentor.instrument()
        return instrumentor
    else:
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            "enable_tracing=True but proceeding with tracing disabled "
            "because not all packages for tracing have been installed."
        )
        return None
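

# Illustrative sketch (not part of the library): a custom
# `instrumentor_builder` only needs to accept `project_id` and return an
# object that the agent holds opaquely. Exporting spans to the console
# instead of Cloud Trace is an assumption chosen for the example.
def _example_instrumentor_builder(project_id: str):
    from openinference.instrumentation.langchain import LangChainInstrumentor
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import ConsoleSpanExporter
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor

    # project_id is unused here because the console exporter is local-only.
    tracer_provider = TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
    trace.set_tracer_provider(tracer_provider)
    instrumentor = LangChainInstrumentor()
    instrumentor.instrument()
    return instrumentor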


def _default_prompt(
    has_history: bool,
    system_instruction: Optional[str] = None,
) -> "RunnableSerializable":
    from langchain_core import prompts

    try:
        from langchain.agents.format_scratchpad.tools import format_to_tool_messages
    except (ModuleNotFoundError, ImportError):
        # Fall back to an older version if needed.
        from langchain.agents.format_scratchpad.openai_tools import (
            format_to_openai_tool_messages as format_to_tool_messages,
        )

    system_instructions = []
    if system_instruction:
        system_instructions = [("system", system_instruction)]

    if has_history:
        return {
            "history": lambda x: x["history"],
            "input": lambda x: x["input"],
            "agent_scratchpad": (
                lambda x: format_to_tool_messages(x["intermediate_steps"])
            ),
        } | prompts.ChatPromptTemplate.from_messages(
            system_instructions
            + [
                prompts.MessagesPlaceholder(variable_name="history"),
                ("user", "{input}"),
                prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
    else:
        return {
            "input": lambda x: x["input"],
            "agent_scratchpad": (
                lambda x: format_to_tool_messages(x["intermediate_steps"])
            ),
        } | prompts.ChatPromptTemplate.from_messages(
            system_instructions
            + [
                ("user", "{input}"),
                prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
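

# Illustrative sketch (not part of the library): the dict piped into the
# ChatPromptTemplate above is itself coerced into a runnable that maps the
# executor's input into prompt variables. Invoking the no-history prompt
# directly shows the resulting prompt value; the inputs are example data.
def _example_default_prompt_invocation():
    prompt = _default_prompt(has_history=False, system_instruction="Be terse.")
    return prompt.invoke(
        {
            "input": "What is 1 + 1?",
            # No tool calls have happened yet, so the scratchpad is empty.
            "intermediate_steps": [],
        }
    )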


def _validate_callable_parameters_are_annotated(callable: Callable):
    """Validates that the parameters of the callable have type annotations.

    This ensures that they can be used for constructing LangChain tools that
    are usable with Gemini function calling.
    """
    import inspect

    parameters = dict(inspect.signature(callable).parameters)
    for name, parameter in parameters.items():
        if parameter.annotation == inspect.Parameter.empty:
            raise TypeError(
                f"Callable={callable.__name__} has untyped input_arg={name}. "
                f"Please specify a type when defining it, e.g. `{name}: str`."
            )


def _validate_tools(tools: Sequence["_ToolLike"]):
    """Validates that the tools are usable for tool calling."""
    for tool in tools:
        if isinstance(tool, Callable):
            _validate_callable_parameters_are_annotated(tool)
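

# Illustrative sketch (not part of the library): the validators above reject
# callables with unannotated parameters at construction time, before the
# agent is deployed. Both function names below are hypothetical examples.
def _example_tool_validation():
    def typed_tool(city: str) -> str:
        return f"Weather in {city}: sunny"

    def untyped_tool(city):
        return f"Weather in {city}: sunny"

    _validate_tools([typed_tool])  # passes silently
    try:
        _validate_tools([untyped_tool])
    except TypeError as e:
        # e.g. "Callable=untyped_tool has untyped input_arg=city. ..."
        return str(e)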


def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple LangchainAgents in the same environment,
    it's crucial to manage trace exports carefully.
    Each agent needs its own span processor tied to a unique project ID.
    While we add a new span processor for each agent, this can lead to
    unexpected behavior.
    For instance, with two agents linked to different projects, traces from
    the second agent might be sent to both projects.
    To prevent this and guarantee traces go to the correct project, we
    overwrite the active span processor whenever a new LangchainAgent is
    created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that overrides the tracer provider's
            current active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor


class LangchainAgent:
    """A Langchain Agent.

    See https://cloud.google.com/vertex-ai/generative-ai/docs/reasoning-engine/develop
    for details.
    """

    def __init__(
        self,
        model: str,
        *,
        system_instruction: Optional[str] = None,
        prompt: Optional["RunnableSerializable"] = None,
        tools: Optional[Sequence["_ToolLike"]] = None,
        output_parser: Optional["RunnableSerializable"] = None,
        chat_history: Optional["GetSessionHistoryCallable"] = None,
        model_kwargs: Optional[Mapping[str, Any]] = None,
        model_tool_kwargs: Optional[Mapping[str, Any]] = None,
        agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        model_builder: Optional[Callable] = None,
        runnable_builder: Optional[Callable] = None,
        enable_tracing: bool = False,
        instrumentor_builder: Optional[Callable[..., Any]] = None,
    ):
        """Initializes the LangchainAgent.

        Under the hood, assuming .set_up() is called, this will correspond to

        ```
        model = model_builder(model_name=model, model_kwargs=model_kwargs)
        runnable = runnable_builder(
            prompt=prompt,
            model=model,
            tools=tools,
            output_parser=output_parser,
            chat_history=chat_history,
            agent_executor_kwargs=agent_executor_kwargs,
            runnable_kwargs=runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to

        ```
        # model_builder
        from langchain_google_vertexai import ChatVertexAI
        llm = ChatVertexAI(model_name=model, **model_kwargs)

        # runnable_builder
        from langchain import agents
        from langchain_core.runnables.history import RunnableWithMessageHistory
        llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
        agent_executor = agents.AgentExecutor(
            agent=prompt | llm_with_tools | output_parser,
            tools=tools,
            **agent_executor_kwargs,
        )
        runnable = RunnableWithMessageHistory(
            runnable=agent_executor,
            get_session_history=chat_history,
            **runnable_kwargs,
        )
        ```

        Args:
            model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
            system_instruction (str):
                Optional. The system instruction to use for the agent. This
                argument should not be specified if `prompt` is specified.
            prompt (langchain_core.runnables.RunnableSerializable):
                Optional. The prompt template for the model. Defaults to a
                ChatPromptTemplate.
            tools (Sequence[langchain_core.tools.BaseTool, Callable]):
                Optional. The tools for the agent to be able to use. All input
                callables (e.g. function or class method) will be converted
                to a langchain.tools.base.StructuredTool. Defaults to None.
            output_parser (langchain_core.runnables.RunnableSerializable):
                Optional. The output parser for the model. Defaults to an
                output parser that works with Gemini function-calling.
            chat_history (langchain_core.runnables.history.GetSessionHistoryCallable):
                Optional. Callable that returns a new BaseChatMessageHistory.
                Defaults to None, i.e. chat_history is not preserved.
            model_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                chat_models.ChatVertexAI. An example would be
                ```
                {
                    # temperature (float): Sampling temperature, it controls
                    # the degree of randomness in token selection.
                    "temperature": 0.28,
                    # max_output_tokens (int): Token limit determines the
                    # maximum amount of text output from one prompt.
                    "max_output_tokens": 1000,
                    # top_p (float): Tokens are selected from most probable to
                    # least, until the sum of their probabilities equals the
                    # top_p value.
                    "top_p": 0.95,
                    # top_k (int): How the model selects tokens for output;
                    # the next token is selected from among the top_k most
                    # probable tokens.
                    "top_k": 40,
                }
                ```
            model_tool_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments when binding tools to
                the model using `model.bind_tools()`.
            agent_executor_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                langchain.agents.AgentExecutor. An example would be
                ```
                {
                    # Whether to return the agent's trajectory of intermediate
                    # steps at the end in addition to the final output.
                    "return_intermediate_steps": False,
                    # The maximum number of steps to take before ending the
                    # execution loop.
                    "max_iterations": 15,
                    # The method to use for early stopping if the agent never
                    # returns `AgentFinish`. Either 'force' or 'generate'.
                    "early_stopping_method": "force",
                    # How to handle errors raised by the agent's output
                    # parser. Defaults to `False`, which raises the error.
                    "handle_parsing_errors": False,
                }
                ```
            runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                langchain.runnables.history.RunnableWithMessageHistory if
                chat_history is specified. If chat_history is None, this will
                be ignored.
            model_builder (Callable):
                Optional. Callable that returns a new language model. Defaults
                to a callable that returns ChatVertexAI based on `model`,
                `model_kwargs` and the parameters in `vertexai.init`.
            runnable_builder (Callable):
                Optional. Callable that returns a new runnable. This can be
                used for customizing the orchestration logic of the Agent
                based on the model returned by `model_builder` and the rest
                of the input arguments.
            enable_tracing (bool):
                Optional. Whether to enable tracing in Cloud Trace. Defaults
                to False.
            instrumentor_builder (Callable[..., Any]):
                Optional. Callable that returns a new instrumentor. This can
                be used for customizing the instrumentation logic of the
                Agent. If not provided, a default instrumentor builder will be
                used. This parameter is ignored if `enable_tracing` is False.

        Raises:
            ValueError: If both `prompt` and `system_instruction` are
                specified.
            TypeError: If there is an invalid tool (e.g. a function with an
                input that did not specify its type).
        """
        from google.cloud.aiplatform import initializer

        self._tmpl_attrs: dict[str, Any] = {
            "project": initializer.global_config.project,
            "location": initializer.global_config.location,
            "tools": [],
            "model_name": model,
            "system_instruction": system_instruction,
            "prompt": prompt,
            "output_parser": output_parser,
            "chat_history": chat_history,
            "model_kwargs": model_kwargs,
            "model_tool_kwargs": model_tool_kwargs,
            "agent_executor_kwargs": agent_executor_kwargs,
            "runnable_kwargs": runnable_kwargs,
            "model_builder": model_builder,
            "runnable_builder": runnable_builder,
            "enable_tracing": enable_tracing,
            "model": None,
            "runnable": None,
            "instrumentor": None,
            "instrumentor_builder": instrumentor_builder,
        }
        if tools:
            # We validate tools at initialization for actionable feedback
            # before they are deployed.
            _validate_tools(tools)
            self._tmpl_attrs["tools"] = tools
        if prompt and system_instruction:
            raise ValueError(
                "Only one of `prompt` or `system_instruction` should be "
                "specified. Consider incorporating the system instruction "
                "into the prompt rather than passing it separately as an "
                "argument."
            )

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It initializes the model, binds the model with tools, and connects it
        with the prompt template and output parser.

        This method should not be called for an object being passed to the
        service for deployment, as it might initialize clients that cannot
        be serialized.
        """
        if self._tmpl_attrs.get("enable_tracing"):
            instrumentor_builder = (
                self._tmpl_attrs.get("instrumentor_builder")
                or _default_instrumentor_builder
            )
            self._tmpl_attrs["instrumentor"] = instrumentor_builder(
                project_id=self._tmpl_attrs.get("project")
            )
        model_builder = self._tmpl_attrs.get("model_builder") or _default_model_builder
        self._tmpl_attrs["model"] = model_builder(
            model_name=self._tmpl_attrs.get("model_name"),
            model_kwargs=self._tmpl_attrs.get("model_kwargs"),
            project=self._tmpl_attrs.get("project"),
            location=self._tmpl_attrs.get("location"),
        )
        runnable_builder = (
            self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
        )
        self._tmpl_attrs["runnable"] = runnable_builder(
            prompt=self._tmpl_attrs.get("prompt"),
            model=self._tmpl_attrs.get("model"),
            tools=self._tmpl_attrs.get("tools"),
            system_instruction=self._tmpl_attrs.get("system_instruction"),
            output_parser=self._tmpl_attrs.get("output_parser"),
            chat_history=self._tmpl_attrs.get("chat_history"),
            model_tool_kwargs=self._tmpl_attrs.get("model_tool_kwargs"),
            agent_executor_kwargs=self._tmpl_attrs.get("agent_executor_kwargs"),
            runnable_kwargs=self._tmpl_attrs.get("runnable_kwargs"),
        )

    def clone(self) -> "LangchainAgent":
        """Returns a clone of the LangchainAgent."""
        import copy

        return LangchainAgent(
            model=self._tmpl_attrs.get("model_name"),
            system_instruction=self._tmpl_attrs.get("system_instruction"),
            prompt=copy.deepcopy(self._tmpl_attrs.get("prompt")),
            tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
            output_parser=copy.deepcopy(self._tmpl_attrs.get("output_parser")),
            chat_history=copy.deepcopy(self._tmpl_attrs.get("chat_history")),
            model_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_kwargs")),
            model_tool_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_tool_kwargs")),
            agent_executor_kwargs=copy.deepcopy(
                self._tmpl_attrs.get("agent_executor_kwargs")
            ),
            runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
            model_builder=self._tmpl_attrs.get("model_builder"),
            runnable_builder=self._tmpl_attrs.get("runnable_builder"),
            enable_tracing=self._tmpl_attrs.get("enable_tracing"),
            instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
        )

    def query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the
                Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.invoke()` method of the corresponding AgentExecutor.

        Returns:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()
        return langchain_load_dump.dumpd(
            self._tmpl_attrs.get("runnable").invoke(
                input=input, config=config, **kwargs
            )
        )

    def stream_query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs,
    ) -> Iterable[Any]:
        """Stream queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the
                Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.stream()` method of the corresponding AgentExecutor.

        Yields:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()
        for chunk in self._tmpl_attrs.get("runnable").stream(
            input=input,
            config=config,
            **kwargs,
        ):
            yield langchain_load_dump.dumpd(chunk)
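

# Illustrative local usage sketch (not part of the library). It assumes
# `vertexai.init(...)` has been called and the LangChain dependencies are
# installed; the session store, tool, and session id are hypothetical
# examples.
def _example_langchain_usage():
    from langchain_core.chat_history import InMemoryChatMessageHistory

    session_store: Dict[str, InMemoryChatMessageHistory] = {}

    def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
        return session_store.setdefault(session_id, InMemoryChatMessageHistory())

    def get_weather(city: str) -> str:
        return f"It is always sunny in {city}."

    agent = LangchainAgent(
        model="gemini-1.5-pro",
        tools=[get_weather],
        chat_history=get_session_history,
    )
    # With chat_history set, each call must carry a session_id in the config.
    config = {"configurable": {"session_id": "demo-session"}}
    response = agent.query(input="What's the weather in Paris?", config=config)
    for chunk in agent.stream_query(input="And in Tokyo?", config=config):
        pass  # each chunk is a JSON-serializable dict
    return response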

@@ -0,0 +1,692 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    Union,
)

if TYPE_CHECKING:
    try:
        from langchain_core import runnables
        from langchain_core import tools as lc_tools
        from langchain_core.language_models import base as lc_language_models

        BaseTool = lc_tools.BaseTool
        BaseLanguageModel = lc_language_models.BaseLanguageModel
        RunnableConfig = runnables.RunnableConfig
        RunnableSerializable = runnables.RunnableSerializable
    except ImportError:
        BaseTool = Any
        BaseLanguageModel = Any
        RunnableConfig = Any
        RunnableSerializable = Any

    try:
        from langchain_google_vertexai.functions_utils import _ToolsType

        _ToolLike = _ToolsType
    except ImportError:
        _ToolLike = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any

    try:
        from langgraph_checkpoint.checkpoint import base

        BaseCheckpointSaver = base.BaseCheckpointSaver
    except ImportError:
        try:
            from langgraph.checkpoint import base

            BaseCheckpointSaver = base.BaseCheckpointSaver
        except ImportError:
            BaseCheckpointSaver = Any


def _default_model_builder(
    model_name: str,
    *,
    project: str,
    location: str,
    model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
    """Default callable for building a language model.

    Args:
        model_name (str):
            Required. The name of the model (e.g. "gemini-1.0-pro").
        project (str):
            Required. The Google Cloud project ID.
        location (str):
            Required. The Google Cloud location.
        model_kwargs (Mapping[str, Any]):
            Optional. Additional keyword arguments for the constructor of
            chat_models.ChatVertexAI.

    Returns:
        BaseLanguageModel: The language model.
    """
    import vertexai
    from google.cloud.aiplatform import initializer
    from langchain_google_vertexai import ChatVertexAI

    model_kwargs = model_kwargs or {}
    current_project = initializer.global_config.project
    current_location = initializer.global_config.location
    vertexai.init(project=project, location=location)
    model = ChatVertexAI(model_name=model_name, **model_kwargs)
    vertexai.init(project=current_project, location=current_location)
    return model


def _default_runnable_builder(
    model: "BaseLanguageModel",
    *,
    tools: Optional[Sequence["_ToolLike"]] = None,
    checkpointer: Optional[Any] = None,
    model_tool_kwargs: Optional[Mapping[str, Any]] = None,
    runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
    """Default callable for building a runnable.

    Args:
        model (BaseLanguageModel):
            Required. The language model.
        tools (Optional[Sequence[_ToolLike]]):
            Optional. The tools for the agent to be able to use.
        checkpointer (Optional[Checkpointer]):
            Optional. The checkpointer for the agent.
        model_tool_kwargs (Optional[Mapping[str, Any]]):
            Optional. Additional keyword arguments when binding tools to the
            model.
        runnable_kwargs (Optional[Mapping[str, Any]]):
            Optional. Additional keyword arguments for the runnable.

    Returns:
        RunnableSerializable: The runnable.
    """
    from langgraph import prebuilt as langgraph_prebuilt

    model_tool_kwargs = model_tool_kwargs or {}
    runnable_kwargs = runnable_kwargs or {}
    if tools:
        model = model.bind_tools(tools=tools, **model_tool_kwargs)
    else:
        tools = []
    if checkpointer:
        if "checkpointer" in runnable_kwargs:
            from google.cloud.aiplatform import base

            base.Logger(__name__).warning(
                "checkpointer is being specified in both checkpointer_builder "
                "and runnable_kwargs. Please specify it in only one of them. "
                "Overriding the checkpointer in runnable_kwargs."
            )
        runnable_kwargs["checkpointer"] = checkpointer
    return langgraph_prebuilt.create_react_agent(
        model,
        tools=tools,
        **runnable_kwargs,
    )


def _default_instrumentor_builder(project_id: str):
    from vertexai.agent_engines import _utils

    cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
    cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
    openinference_langchain = _utils._import_openinference_langchain_or_warn()
    opentelemetry = _utils._import_opentelemetry_or_warn()
    opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
    if all(
        (
            cloud_trace_exporter,
            cloud_trace_v2,
            openinference_langchain,
            opentelemetry,
            opentelemetry_sdk_trace,
        )
    ):
        import google.auth

        credentials, _ = google.auth.default()
        span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
            project_id=project_id,
            client=cloud_trace_v2.TraceServiceClient(
                credentials=credentials.with_quota_project(project_id),
            ),
        )
        span_processor: SpanProcessor = (
            opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                span_exporter=span_exporter,
            )
        )
        tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
        # Get the appropriate tracer provider:
        # 1. If _TRACER_PROVIDER is already set, use that.
        # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
        #    variable is set, use that.
        # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
        # If none of the above is set, we log a warning and create a new
        # tracer provider.
        if not tracer_provider:
            from google.cloud.aiplatform import base

            base.Logger(__name__).warning(
                "No tracer provider. By default, "
                "we should get one of the following providers: "
                "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                "or _PROXY_TRACER_PROVIDER."
            )
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids AttributeError:
        # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
        # attribute 'add_span_processor'.
        if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids the OpenTelemetry "client already exists" error.
        _override_active_span_processor(
            tracer_provider,
            opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
        )
        tracer_provider.add_span_processor(span_processor)
        # Keep the instrumentation up-to-date when creating multiple agents:
        # we deliberately override the instrumentation each time, so that if
        # different agents end up using different instrumentations, the user
        # is always working with the most recent agent's instrumentation.
        instrumentor = openinference_langchain.LangChainInstrumentor()
        if instrumentor.is_instrumented_by_opentelemetry:
            instrumentor.uninstrument()
        instrumentor.instrument()
        return instrumentor
    else:
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            "enable_tracing=True but proceeding with tracing disabled "
            "because not all packages for tracing have been installed."
        )
        return None


def _validate_callable_parameters_are_annotated(callable: Callable):
    """Validates that the parameters of the callable have type annotations.

    This ensures that they can be used for constructing LangChain tools that
    are usable with Gemini function calling.

    Args:
        callable (Callable): The callable to validate.

    Raises:
        TypeError: If any parameter is not annotated.
    """
    import inspect

    parameters = dict(inspect.signature(callable).parameters)
    for name, parameter in parameters.items():
        if parameter.annotation == inspect.Parameter.empty:
            raise TypeError(
                f"Callable={callable.__name__} has untyped input_arg={name}. "
                f"Please specify a type when defining it, e.g. `{name}: str`."
            )


def _validate_tools(tools: Sequence["_ToolLike"]):
    """Validates that the tools are usable for tool calling.

    Args:
        tools (Sequence[_ToolLike]): The tools to validate.

    Raises:
        TypeError: If any tool is a callable with untyped parameters.
    """
    for tool in tools:
        if isinstance(tool, Callable):
            _validate_callable_parameters_are_annotated(tool)


def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple LanggraphAgents in the same environment,
    it's crucial to manage trace exports carefully.
    Each agent needs its own span processor tied to a unique project ID.
    While we add a new span processor for each agent, this can lead to
    unexpected behavior.
    For instance, with two agents linked to different projects, traces from
    the second agent might be sent to both projects.
    To prevent this and guarantee traces go to the correct project, we
    overwrite the active span processor whenever a new LanggraphAgent is
    created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that overrides the tracer provider's
            current active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor


class LanggraphAgent:
    """A LangGraph Agent."""

    def __init__(
        self,
        model: str,
        *,
        tools: Optional[Sequence["_ToolLike"]] = None,
        model_kwargs: Optional[Mapping[str, Any]] = None,
        model_tool_kwargs: Optional[Mapping[str, Any]] = None,
        model_builder: Optional[Callable[..., "BaseLanguageModel"]] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_builder: Optional[Callable[..., "RunnableSerializable"]] = None,
        checkpointer_kwargs: Optional[Mapping[str, Any]] = None,
        checkpointer_builder: Optional[Callable[..., "BaseCheckpointSaver"]] = None,
        enable_tracing: bool = False,
        instrumentor_builder: Optional[Callable[..., Any]] = None,
    ):
        """Initializes the LangGraph Agent.

        Under the hood, assuming .set_up() is called, this will correspond to

        ```python
        model = model_builder(model_name=model, model_kwargs=model_kwargs)
        runnable = runnable_builder(
            model=model,
            tools=tools,
            model_tool_kwargs=model_tool_kwargs,
            runnable_kwargs=runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to

        ```python
        # model_builder
        from langchain_google_vertexai import ChatVertexAI
        llm = ChatVertexAI(model_name=model, **model_kwargs)

        # runnable_builder
        from langgraph.prebuilt import create_react_agent
        llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
        runnable = create_react_agent(
            llm_with_tools,
            tools=tools,
            **runnable_kwargs,
        )
        ```

        By default, no checkpointer is used (i.e. there is no state history).
        To enable checkpointing, provide a `checkpointer_builder` function
        that returns a checkpointer instance.

        **Example using Spanner:**
        ```python
        def checkpointer_builder(instance_id, database_id, project_id, **kwargs):
            from langchain_google_spanner import SpannerCheckpointSaver

            checkpointer = SpannerCheckpointSaver(instance_id, database_id, project_id)
            with checkpointer.cursor() as cur:
                cur.execute("DROP TABLE IF EXISTS checkpoints")
                cur.execute("DROP TABLE IF EXISTS checkpoint_writes")
            checkpointer.setup()

            return checkpointer
        ```

        **Example using an in-memory checkpointer:**
        ```python
        def checkpointer_builder(**kwargs):
            from langgraph.checkpoint.memory import MemorySaver

            return MemorySaver()
        ```

        The `checkpointer_builder` function will be called with any keyword
        arguments passed to the agent's constructor. Ensure your
        `checkpointer_builder` function accepts `**kwargs` to handle these
        arguments, even if unused.

        Args:
            model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
            tools (Sequence[langchain_core.tools.BaseTool, Callable]):
                Optional. The tools for the agent to be able to use. All input
                callables (e.g. function or class method) will be converted
                to a langchain.tools.base.StructuredTool. Defaults to None.
            model_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                chat_models.ChatVertexAI. An example would be
                ```
                {
                    # temperature (float): Sampling temperature, it controls
                    # the degree of randomness in token selection.
                    "temperature": 0.28,
                    # max_output_tokens (int): Token limit determines the
                    # maximum amount of text output from one prompt.
                    "max_output_tokens": 1000,
                    # top_p (float): Tokens are selected from most probable to
                    # least, until the sum of their probabilities equals the
                    # top_p value.
                    "top_p": 0.95,
                    # top_k (int): How the model selects tokens for output;
                    # the next token is selected from among the top_k most
                    # probable tokens.
                    "top_k": 40,
                }
                ```
            model_tool_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments when binding tools to
                the model using `model.bind_tools()`.
            model_builder (Callable[..., "BaseLanguageModel"]):
                Optional. Callable that returns a new language model. Defaults
                to a callable that returns ChatVertexAI based on `model`,
                `model_kwargs` and the parameters in `vertexai.init`.
            runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for constructing the
                runnable (by default, these are passed to
                `langgraph.prebuilt.create_react_agent`).
            runnable_builder (Callable[..., "RunnableSerializable"]):
                Optional. Callable that returns a new runnable. This can be
                used for customizing the orchestration logic of the Agent
                based on the model returned by `model_builder` and the rest
                of the input arguments.
            checkpointer_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                the checkpointer returned by `checkpointer_builder`.
            checkpointer_builder (Callable[..., "BaseCheckpointSaver"]):
                Optional. Callable that returns a checkpointer. This can be
                used for defining the checkpointer of the Agent. Defaults to
                None.
            enable_tracing (bool):
                Optional. Whether to enable tracing in Cloud Trace. Defaults
                to False.
            instrumentor_builder (Callable[..., Any]):
                Optional. Callable that returns a new instrumentor. This can
                be used for customizing the instrumentation logic of the
                Agent. If not provided, a default instrumentor builder will be
                used. This parameter is ignored if `enable_tracing` is False.

        Raises:
            TypeError: If there is an invalid tool (e.g. a function with an
                input that did not specify its type).
        """
        from google.cloud.aiplatform import initializer

        self._tmpl_attrs: dict[str, Any] = {
            "project": initializer.global_config.project,
            "location": initializer.global_config.location,
            "tools": [],
            "model_name": model,
            "model_kwargs": model_kwargs,
            "model_tool_kwargs": model_tool_kwargs,
            "runnable_kwargs": runnable_kwargs,
            "checkpointer_kwargs": checkpointer_kwargs,
            "model": None,
            "model_builder": model_builder,
            "runnable": None,
            "runnable_builder": runnable_builder,
            "checkpointer": None,
            "checkpointer_builder": checkpointer_builder,
            "enable_tracing": enable_tracing,
            "instrumentor": None,
            "instrumentor_builder": instrumentor_builder,
        }
        if tools:
            # We validate tools at initialization for actionable feedback
            # before they are deployed.
            _validate_tools(tools)
            self._tmpl_attrs["tools"] = tools

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It initializes the model, binds the model with tools, and constructs
        the runnable (with a checkpointer, if one is configured).

        This method should not be called for an object that is being passed
        to the ReasoningEngine service for deployment, as it initializes
        clients that cannot be serialized.
        """
        if self._tmpl_attrs.get("enable_tracing"):
            instrumentor_builder = (
                self._tmpl_attrs.get("instrumentor_builder")
                or _default_instrumentor_builder
            )
            self._tmpl_attrs["instrumentor"] = instrumentor_builder(
                project_id=self._tmpl_attrs.get("project")
            )
        model_builder = self._tmpl_attrs.get("model_builder") or _default_model_builder
        self._tmpl_attrs["model"] = model_builder(
            model_name=self._tmpl_attrs.get("model_name"),
            model_kwargs=self._tmpl_attrs.get("model_kwargs"),
            project=self._tmpl_attrs.get("project"),
            location=self._tmpl_attrs.get("location"),
        )
        checkpointer_builder = self._tmpl_attrs.get("checkpointer_builder")
        if checkpointer_builder:
            checkpointer_kwargs = self._tmpl_attrs.get("checkpointer_kwargs") or {}
            self._tmpl_attrs["checkpointer"] = checkpointer_builder(
                **checkpointer_kwargs
            )
        runnable_builder = (
            self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
        )
        self._tmpl_attrs["runnable"] = runnable_builder(
            model=self._tmpl_attrs.get("model"),
            tools=self._tmpl_attrs.get("tools"),
            checkpointer=self._tmpl_attrs.get("checkpointer"),
            model_tool_kwargs=self._tmpl_attrs.get("model_tool_kwargs"),
            runnable_kwargs=self._tmpl_attrs.get("runnable_kwargs"),
        )

    def clone(self) -> "LanggraphAgent":
        """Returns a clone of the LanggraphAgent."""
        import copy

        return LanggraphAgent(
            model=self._tmpl_attrs.get("model_name"),
            tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
            model_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_kwargs")),
            model_tool_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_tool_kwargs")),
            runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
            checkpointer_kwargs=copy.deepcopy(
                self._tmpl_attrs.get("checkpointer_kwargs")
            ),
            model_builder=self._tmpl_attrs.get("model_builder"),
            runnable_builder=self._tmpl_attrs.get("runnable_builder"),
            checkpointer_builder=self._tmpl_attrs.get("checkpointer_builder"),
            enable_tracing=self._tmpl_attrs.get("enable_tracing"),
            instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
        )

    def query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the
                Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.invoke()` method of the corresponding runnable.

        Returns:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()
        return langchain_load_dump.dumpd(
            self._tmpl_attrs.get("runnable").invoke(
                input=input, config=config, **kwargs
            )
        )

    def stream_query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs,
    ) -> Iterable[Any]:
        """Stream queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the
                Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.stream()` method of the corresponding runnable.

        Yields:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()
        for chunk in self._tmpl_attrs.get("runnable").stream(
            input=input,
            config=config,
            **kwargs,
        ):
            yield langchain_load_dump.dumpd(chunk)

    def get_state_history(
        self,
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Iterable[Any]:
        """Gets the state history of the Agent.

        Args:
            config (Optional[RunnableConfig]):
                Optional. The config for invoking the Agent.
            **kwargs:
                Optional. Additional keyword arguments for the
                `.get_state_history()` method.

        Yields:
            Dict[str, Any]: The state history of the Agent.
        """
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()
        for state_snapshot in self._tmpl_attrs.get("runnable").get_state_history(
            config=config,
            **kwargs,
        ):
            yield state_snapshot._asdict()

    def get_state(
        self,
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Gets the current state of the Agent.

        Args:
            config (Optional[RunnableConfig]):
                Optional. The config for invoking the Agent.
            **kwargs:
                Optional. Additional keyword arguments for the `.get_state()`
                method.

        Returns:
            Dict[str, Any]: The current state of the Agent.
        """
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()
        return (
            self._tmpl_attrs.get("runnable")
            .get_state(config=config, **kwargs)
            ._asdict()
        )

    def update_state(
        self,
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Updates the state of the Agent.

        Args:
            config (Optional[RunnableConfig]):
                Optional. The config for invoking the Agent.
            **kwargs:
                Optional. Additional keyword arguments for the
                `.update_state()` method.

        Returns:
            Dict[str, Any]: The updated state of the Agent.
        """
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()
        return self._tmpl_attrs.get("runnable").update_state(config=config, **kwargs)

    def register_operations(self) -> Mapping[str, Sequence[str]]:
        """Registers the operations of the Agent.

        This mapping defines how different operation modes (e.g., "" and
        "stream") are implemented by specific methods of the Agent. The
        "default" mode, represented by the empty string, is associated with
        the `query` API, while the "stream" mode is associated with the
        `stream_query` API.

        Returns:
            Mapping[str, Sequence[str]]: A mapping of operation modes to a
            list of method names that implement those operation modes.
        """
        return {
            "": ["query", "get_state", "update_state"],
            "stream": ["stream_query", "get_state_history"],
        }
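

# Illustrative local usage sketch (not part of the library). It assumes
# `vertexai.init(...)` has been called and the LangGraph dependencies are
# installed; the tool and thread id below are hypothetical examples.
def _example_langgraph_usage():
    def get_weather(city: str) -> str:
        return f"It is always sunny in {city}."

    def checkpointer_builder(**kwargs):
        from langgraph.checkpoint.memory import MemorySaver

        return MemorySaver()

    agent = LanggraphAgent(
        model="gemini-1.5-pro",
        tools=[get_weather],
        checkpointer_builder=checkpointer_builder,
    )
    # With a checkpointer, each call must carry a thread_id in the config.
    config = {"configurable": {"thread_id": "demo-thread"}}
    response = agent.query(
        input={"messages": [("user", "What's the weather in Paris?")]},
        config=config,
    )
    state = agent.get_state(config=config)  # current checkpointed state
    history = list(agent.get_state_history(config=config))
    return response, state, history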