structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,327 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes and functions for working with agent engines."""
from typing import Dict, Iterable, Optional, Sequence, Union
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aip_utils
from google.cloud.aiplatform_v1 import types as aip_types
# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.agent_engines._agent_engines import (
AgentEngine,
Cloneable,
ModuleAgent,
OperationRegistrable,
Queryable,
StreamQueryable,
)
from vertexai.agent_engines.templates.ag2 import (
AG2Agent,
)
from vertexai.agent_engines.templates.langchain import (
LangchainAgent,
)
from vertexai.agent_engines.templates.langgraph import (
LanggraphAgent,
)
_LOGGER = base.Logger(__name__)
def get(resource_name: str) -> AgentEngine:
"""Retrieves an Agent Engine resource.
Args:
resource_name (str):
Required. A fully-qualified resource name or ID such as
"projects/123/locations/us-central1/reasoningEngines/456" or
"456" when project and location are initialized or passed.
"""
return AgentEngine(resource_name)
def create(
agent_engine: Optional[Union[Queryable, OperationRegistrable]] = None,
*,
requirements: Optional[Union[str, Sequence[str]]] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
gcs_dir_name: Optional[str] = None,
extra_packages: Optional[Sequence[str]] = None,
env_vars: Optional[
Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]]
] = None,
) -> AgentEngine:
"""Creates a new Agent Engine.
The Agent Engine will be an instance of the `agent_engine` that
was passed in, running remotely on Vertex AI.
Sample ``src_dir`` contents (e.g. ``./user_src_dir``):
.. code-block:: python
user_src_dir/
|-- main.py
|-- requirements.txt
|-- user_code/
| |-- utils.py
| |-- ...
|-- ...
To build an Agent Engine with the above files, run:
.. code-block:: python
remote_agent = agent_engines.create(
agent_engine=local_agent,
requirements=[
# I.e. the PyPI dependencies listed in requirements.txt
"google-cloud-aiplatform==1.25.0",
"langchain==0.0.242",
...
],
extra_packages=[
"./user_src_dir/main.py", # a single file
"./user_src_dir/user_code", # a directory
...
],
)
Args:
agent_engine (AgentEngineInterface):
Required. The Agent Engine to be created.
requirements (Union[str, Sequence[str]]):
Optional. The set of PyPI dependencies needed. It can either be
the path to a single file (requirements.txt), or an ordered list
of strings corresponding to each line of the requirements file.
display_name (str):
Optional. The user-defined name of the Agent Engine.
The name can be up to 128 characters long and can comprise any
UTF-8 character.
description (str):
Optional. The description of the Agent Engine.
gcs_dir_name (str):
Optional. The GCS bucket directory under `staging_bucket` to
use for staging the artifacts needed.
extra_packages (Sequence[str]):
Optional. The set of extra user-provided packages (if any).
env_vars (Union[Sequence[str], Dict[str, Union[str, SecretRef]]]):
Optional. The environment variables to be set when running the
Agent Engine. If it is a list of strings, each string should be
a valid key to `os.environ`. If it is a dictionary, the keys are
the environment variable names, and the values are the
corresponding values.
Returns:
AgentEngine: The Agent Engine that was created.
Raises:
ValueError: If the `project` was not set using `vertexai.init`.
ValueError: If the `location` was not set using `vertexai.init`.
ValueError: If the `staging_bucket` was not set using vertexai.init.
ValueError: If the `staging_bucket` does not start with "gs://".
FileNotFoundError: If `extra_packages` includes a file or directory
that does not exist.
IOError: If requirements is a string that corresponds to a
nonexistent file.
"""
return AgentEngine.create(
agent_engine=agent_engine,
requirements=requirements,
display_name=display_name,
description=description,
gcs_dir_name=gcs_dir_name,
extra_packages=extra_packages,
env_vars=env_vars,
)
def list(*, filter: str = "") -> Iterable[AgentEngine]:
"""List all instances of Agent Engine matching the filter.
Example Usage:
.. code-block:: python
import vertexai
from vertexai import agent_engines
vertexai.init(project="my_project", location="us-central1")
agent_engines.list(filter='display_name="My Custom Agent"')
Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
Returns:
Iterable[AgentEngine]: An iterable of Agent Engines matching the filter.
"""
api_client = initializer.global_config.create_client(
client_class=aip_utils.AgentEngineClientWithOverride,
)
for agent in api_client.list_reasoning_engines(
request=aip_types.ListReasoningEnginesRequest(
parent=initializer.global_config.common_location_path(),
filter=filter,
)
):
yield AgentEngine(agent.name)
def delete(
resource_name: str,
*,
force: bool = False,
**kwargs,
) -> None:
"""Delete an Agent Engine resource.
Args:
resource_name (str):
Required. The name of the Agent Engine to be deleted. Format:
`projects/{project}/locations/{location}/reasoningEngines/{resource_id}`
force (bool):
Optional. If set to True, child resources will also be deleted.
Otherwise, the request will fail with FAILED_PRECONDITION error
when the Agent Engine has undeleted child resources. Defaults to
False.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
delete_reasoning_engine method.
"""
api_client = initializer.global_config.create_client(
client_class=aip_utils.AgentEngineClientWithOverride,
)
_LOGGER.info(f"Deleting AgentEngine resource: {resource_name}")
operation_future = api_client.delete_reasoning_engine(
request=aip_types.DeleteReasoningEngineRequest(
name=resource_name,
force=force,
**(kwargs or {}),
)
)
_LOGGER.info(f"Delete AgentEngine backing LRO: {operation_future.operation.name}")
operation_future.result()
_LOGGER.info(f"AgentEngine resource deleted: {resource_name}")
def update(
resource_name: str,
*,
agent_engine: Optional[Union[Queryable, OperationRegistrable]] = None,
requirements: Optional[Union[str, Sequence[str]]] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
gcs_dir_name: Optional[str] = None,
extra_packages: Optional[Sequence[str]] = None,
env_vars: Optional[
Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]]
] = None,
) -> "AgentEngine":
"""Updates an existing Agent Engine.
This method updates the configuration of a deployed Agent Engine, identified
by its resource name. Unlike the `create` function which requires an
`agent_engine` object, all arguments in this method are optional. This
method allows you to modify individual aspects of the configuration by
providing any of the optional arguments.
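Example Usage (a minimal sketch; assumes the resource name below refers to
an existing Agent Engine and that a `staging_bucket` is available):
.. code-block:: python
    import vertexai
    from vertexai import agent_engines
    vertexai.init(
        project="my_project",
        location="us-central1",
        staging_bucket="gs://my_staging_bucket",
    )
    # Update only the display name of the deployed Agent Engine.
    agent_engines.update(
        "projects/123/locations/us-central1/reasoningEngines/456",
        display_name="My Renamed Agent",
    )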
Args:
resource_name (str):
Required. The name of the Agent Engine to be updated. Format:
`projects/{project}/locations/{location}/reasoningEngines/{resource_id}`.
agent_engine (AgentEngineInterface):
Optional. The instance to be used as the updated Agent Engine. If it
is not specified, the existing instance will be used.
requirements (Union[str, Sequence[str]]):
Optional. The set of PyPI dependencies needed. It can either be
the path to a single file (requirements.txt), or an ordered list
of strings corresponding to each line of the requirements file.
If it is not specified, the existing requirements will be used.
If it is set to an empty string or list, the existing
requirements will be removed.
display_name (str):
Optional. The user-defined name of the Agent Engine.
The name can be up to 128 characters long and can comprise any
UTF-8 character.
description (str):
Optional. The description of the Agent Engine.
gcs_dir_name (str):
Optional. The GCS bucket directory under `staging_bucket` to
use for staging the artifacts needed.
extra_packages (Sequence[str]):
Optional. The set of extra user-provided packages (if any). If
it is not specified, the existing extra packages will be used.
If it is set to an empty list, the existing extra packages will
be removed.
env_vars (Union[Sequence[str], Dict[str, Union[str, SecretRef]]]):
Optional. The environment variables to be set when running the
Agent Engine. If it is a list of strings, each string should be
a valid key to `os.environ`. If it is a dictionary, the keys are
the environment variable names, and the values are the
corresponding values.
Returns:
AgentEngine: The Agent Engine that was updated.
Raises:
ValueError: If the `staging_bucket` was not set using vertexai.init.
ValueError: If the `staging_bucket` does not start with "gs://".
FileNotFoundError: If `extra_packages` includes a file or directory
that does not exist.
ValueError: if none of `display_name`, `description`,
`requirements`, `extra_packages`, or `agent_engine` were
specified.
IOError: If requirements is a string that corresponds to a
nonexistent file.
"""
agent = get(resource_name)
return agent.update(
agent_engine=agent_engine,
requirements=requirements,
display_name=display_name,
description=description,
gcs_dir_name=gcs_dir_name,
extra_packages=extra_packages,
env_vars=env_vars,
)
__all__ = (
# Resources
"AgentEngine",
# Protocols
"Cloneable",
"OperationRegistrable",
"Queryable",
"StreamQueryable",
# Methods
"create",
"delete",
"get",
"list",
"update",
# Templates
"ModuleAgent",
"LangchainAgent",
"LanggraphAgent",
"AG2Agent",
)

File diff suppressed because it is too large.


@@ -0,0 +1,769 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import inspect
import json
import sys
import types
import typing
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
TypedDict,
Union,
)
from importlib import metadata as importlib_metadata
import proto
from google.cloud.aiplatform import base
from google.api import httpbody_pb2
from google.protobuf import struct_pb2
from google.protobuf import json_format
try:
# For LangChain templates, they might not import langchain_core and get
# PydanticUserError: `query` is not fully defined; you should define
# `RunnableConfig`, then call `query.model_rebuild()`.
import langchain_core.runnables.config
RunnableConfig = langchain_core.runnables.config.RunnableConfig
except ImportError:
RunnableConfig = Any
try:
import packaging
SpecifierSet = packaging.specifiers.SpecifierSet
except AttributeError:
SpecifierSet = Any
try:
_BUILTIN_MODULE_NAMES: Sequence[str] = sys.builtin_module_names
except AttributeError:
_BUILTIN_MODULE_NAMES: Sequence[str] = []
try:
# sys.stdlib_module_names is available from Python 3.10 onwards.
_STDLIB_MODULE_NAMES: frozenset = sys.stdlib_module_names
except AttributeError:
_STDLIB_MODULE_NAMES: frozenset = frozenset()
try:
_PACKAGE_DISTRIBUTIONS: Mapping[
str, Sequence[str]
] = importlib_metadata.packages_distributions()
except AttributeError:
_PACKAGE_DISTRIBUTIONS: Mapping[str, Sequence[str]] = {}
try:
from autogen.agentchat import chat
AutogenChatResult = chat.ChatResult
except ImportError:
AutogenChatResult = Any
try:
from autogen.io import run_response
AutogenRunResponse = run_response.RunResponse
except ImportError:
AutogenRunResponse = Any
JsonDict = Dict[str, Any]
class _RequirementsValidationActions(TypedDict):
append: Set[str]
class _RequirementsValidationWarnings(TypedDict):
missing: Set[str]
incompatible: Set[str]
class _RequirementsValidationResult(TypedDict):
warnings: _RequirementsValidationWarnings
actions: _RequirementsValidationActions
LOGGER = base.Logger("vertexai.agent_engines")
_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES))
_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"])
_ACTIONS_KEY = "actions"
_ACTION_APPEND = "append"
_WARNINGS_KEY = "warnings"
_WARNING_MISSING = "missing"
_WARNING_INCOMPATIBLE = "incompatible"
def to_proto(
obj: Union[JsonDict, proto.Message],
message: Optional[proto.Message] = None,
) -> proto.Message:
"""Parses a JSON-like object into a message.
If the object is already a message, this will return the object as-is. If
the object is a JSON Dict, this will parse and merge the object into the
message.
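Example (a minimal sketch):
.. code-block:: python
    from google.protobuf import struct_pb2
    # A JSON dict is merged into a new Struct, the default message type.
    struct = to_proto({"temperature": 0.2, "top_k": 40})
    assert isinstance(struct, struct_pb2.Struct)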
Args:
obj (Union[dict[str, Any], proto.Message]):
Required. The object to convert to a proto message.
message (proto.Message):
Optional. A protocol buffer message to merge the obj into. It
defaults to Struct() if unspecified.
Returns:
proto.Message: The same message passed as argument.
"""
if message is None:
message = struct_pb2.Struct()
if isinstance(obj, (proto.Message, struct_pb2.Struct)):
return obj
try:
json_format.ParseDict(obj, message._pb)
except AttributeError:
json_format.ParseDict(obj, message)
return message
def to_dict(message: proto.Message) -> JsonDict:
"""Converts the contents of the protobuf message to JSON format.
Args:
message (proto.Message):
Required. The proto message to be converted to a JSON dictionary.
Returns:
dict[str, Any]: A dictionary containing the contents of the proto.
"""
try:
# Best effort attempt to convert the message into a JSON dictionary.
result: JsonDict = json.loads(
json_format.MessageToJson(
message._pb,
preserving_proto_field_name=True,
)
)
except AttributeError:
result: JsonDict = json.loads(
json_format.MessageToJson(
message,
preserving_proto_field_name=True,
)
)
return result
def _dataclass_to_dict_or_raise(obj: Any) -> JsonDict:
"""Converts a dataclass to a JSON dictionary.
Args:
obj (Any):
Required. The dataclass to be converted to a JSON dictionary.
Returns:
dict[str, Any]: A dictionary containing the contents of the dataclass.
Raises:
TypeError: If the object is not a dataclass.
"""
if not dataclasses.is_dataclass(obj):
raise TypeError(f"Object is not a dataclass: {obj}")
return json.loads(json.dumps(dataclasses.asdict(obj)))
def _autogen_run_response_protocol_to_dict(
obj: AutogenRunResponse,
) -> JsonDict:
"""Converts an AutogenRunResponse object into a JSON-serializable dictionary.
This function takes a `RunResponseProtocol` object and transforms its
relevant attributes into a dictionary format suitable for JSON conversion.
The `RunResponseProtocol` defines the structure of the response object,
which typically includes:
* **summary** (`Optional[str]`):
A textual summary of the run.
* **messages** (`Iterable[Message]`):
A sequence of messages exchanged during the run.
Each message is expected to be a JSON-serializable dictionary (`Dict[str,
Any]`).
* **events** (`Iterable[BaseEvent]`):
A sequence of events that occurred during the run.
Note: The `process()` method, if present, is called before conversion,
which typically clears this event queue.
* **context_variables** (`Optional[dict[str, Any]]`):
A dictionary containing contextual variables from the run.
* **last_speaker** (`Optional[Agent]`):
The agent that produced the last message.
The `Agent` object has attributes like `name` (Optional[str]) and
`description` (Optional[str]).
* **cost** (`Optional[Cost]`):
Information about the computational cost of the run.
The `Cost` object inherits from `pydantic.BaseModel` and is converted
to JSON using its `model_dump_json()` method.
* **process** (`Optional[Callable[[], None]]`):
An optional function (like a console event processor) that is called
before the conversion takes place.
Executing this method often clears the `events` queue.
For a detailed definition of `RunResponseProtocol` and its components, refer
to: https://github.com/ag2ai/ag2/blob/main/autogen/io/run_response.py
Args:
obj (AutogenRunResponse): The AutogenRunResponse object to convert. This
object must conform to the `RunResponseProtocol`.
Returns:
JsonDict: A dictionary representation of the AutogenRunResponse, ready
to be serialized into JSON. The dictionary includes keys like
'summary', 'messages', 'context_variables', 'last_speaker_name',
and 'cost'.
"""
if hasattr(obj, "process"):
obj.process()
last_speaker = None
if getattr(obj, "last_speaker", None) is not None:
last_speaker = {
"name": getattr(obj.last_speaker, "name", None),
"description": getattr(obj.last_speaker, "description", None),
}
cost = None
if getattr(obj, "cost", None) is not None:
if hasattr(obj.cost, "model_dump_json"):
cost = json.loads(obj.cost.model_dump_json())
else:
cost = str(obj.cost)
result = {
"summary": getattr(obj, "summary", None),
"messages": list(getattr(obj, "messages", [])),
"context_variables": getattr(obj, "context_variables", None),
"last_speaker": last_speaker,
"cost": cost,
}
return json.loads(json.dumps(result))
def to_json_serializable_autogen_object(
obj: Union[
AutogenChatResult,
AutogenRunResponse,
]
) -> JsonDict:
"""Converts an Autogen object to a JSON serializable object.
In `ag2<=0.8.4`, `.run()` will return a `ChatResult` object.
In `ag2>=0.8.5`, `.run()` will return a `RunResponse` object.
Args:
obj (Union[AutogenChatResult, AutogenRunResponse]):
Required. The Autogen object to be converted to a JSON serializable
object.
Returns:
JsonDict: A JSON serializable object.
"""
if isinstance(obj, AutogenChatResult):
return _dataclass_to_dict_or_raise(obj)
return _autogen_run_response_protocol_to_dict(obj)
def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
"""Converts the contents of the httpbody message to JSON format.
Args:
body (httpbody_pb2.HttpBody):
Required. The httpbody body to be converted to a JSON.
Yields:
Any: A JSON object or the original body if it is not JSON or None.
"""
content_type = getattr(body, "content_type", None)
data = getattr(body, "data", None)
if content_type is None or data is None or "application/json" not in content_type:
yield body
return
try:
utf8_data = data.decode("utf-8")
except Exception as e:
LOGGER.warning(f"Failed to decode data: {data}. Exception: {e}")
yield body
return
if not utf8_data:
yield None
return
# Handle the case of multiple dictionaries delimited by newlines.
for line in utf8_data.split("\n"):
if line:
try:
line = json.loads(line)
except Exception as e:
LOGGER.warning(f"failed to parse json: {line}. Exception: {e}")
yield line
def parse_constraints(
constraints: Sequence[str],
) -> Mapping[str, "SpecifierSet"]:
"""Parses a list of constraints into a dict of requirements.
Args:
constraints (list[str]):
Required. The list of package requirements to parse. This is assumed
to come from the `requirements.txt` file.
Returns:
dict[str, SpecifierSet]: The specifiers for each package.
"""
requirements = _import_packaging_requirements_or_raise()
result = {}
for constraint in constraints:
try:
requirement = requirements.Requirement(constraint)
except Exception as e:
LOGGER.warning(f"Failed to parse constraint: {constraint}. Exception: {e}")
continue
result[requirement.name] = requirement.specifier or None
return result
def validate_requirements_or_warn(
obj: Any,
requirements: List[str],
) -> List[str]:
"""Validates the requirements and returns them, appending any missing required packages."""
requirements = requirements.copy()
try:
current_requirements = scan_requirements(obj)
LOGGER.info(f"Identified the following requirements: {current_requirements}")
constraints = parse_constraints(requirements)
missing_requirements = compare_requirements(current_requirements, constraints)
for warning_type, warnings in missing_requirements.get(
_WARNINGS_KEY, {}
).items():
if warnings:
LOGGER.warning(
f"The following requirements are {warning_type}: {warnings}"
)
for action_type, actions in missing_requirements.get(_ACTIONS_KEY, {}).items():
if actions and action_type == _ACTION_APPEND:
for action in actions:
requirements.append(action)
LOGGER.info(f"The following requirements are appended: {actions}")
except Exception as e:
LOGGER.warning(f"Failed to compile requirements: {e}")
return requirements
def compare_requirements(
requirements: Mapping[str, str],
constraints: Union[Sequence[str], Mapping[str, "SpecifierSet"]],
*,
required_packages: Optional[Sequence[str]] = None,
) -> Mapping[str, Mapping[str, Any]]:
"""Compares the requirements with the constraints.
Args:
requirements (Mapping[str, str]):
Required. The packages (and their versions) to compare with the constraints.
This is assumed to be the result of `scan_requirements`.
constraints (Union[Sequence[str], Mapping[str, SpecifierSet]]):
Required. The package constraints to compare against. This is assumed
to be the result of `parse_constraints`.
required_packages (Sequence[str]):
Optional. The set of packages that are required to be in the
constraints. It defaults to the set of packages that are required
for deployment on Agent Engine.
Returns:
dict[str, dict[str, Any]]: The comparison result as a dictionary containing:
* warnings:
* missing: The set of packages that are not in the constraints.
* incompatible: The set of packages that are in the constraints
but have versions that are not in the constraint specifier.
* actions:
* append: The set of packages that are not in the constraints
but should be appended to the constraints.
"""
packaging_version = _import_packaging_version_or_raise()
if required_packages is None:
required_packages = _DEFAULT_REQUIRED_PACKAGES
result = _RequirementsValidationResult(
warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()),
actions=_RequirementsValidationActions(append=set()),
)
if isinstance(constraints, list):
constraints = parse_constraints(constraints)
for package, package_version in requirements.items():
if package not in constraints:
result[_WARNINGS_KEY][_WARNING_MISSING].add(package)
if package in required_packages:
result[_ACTIONS_KEY][_ACTION_APPEND].add(
f"{package}=={package_version}"
)
continue
if package_version:
package_specifier = constraints[package]
if not package_specifier:
continue
if packaging_version.Version(package_version) not in package_specifier:
result[_WARNINGS_KEY][_WARNING_INCOMPATIBLE].add(
f"{package}=={package_version} (required: {str(package_specifier)})"
)
return result
def scan_requirements(
obj: Any,
ignore_modules: Optional[Sequence[str]] = None,
package_distributions: Optional[Mapping[str, Sequence[str]]] = None,
inspect_getmembers_kwargs: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, str]:
"""Scans the object for modules and returns the requirements discovered.
This is not a comprehensive scan of the object, and only detects for common
cases based on the members of the object returned by `dir(obj)`.
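Example (a minimal sketch; the reported versions depend on what is
installed in the local environment):
.. code-block:: python
    class MyAgent:
        def query(self, input: str) -> str:
            ...
    scan_requirements(MyAgent())
    # e.g. {"cloudpickle": "3.0.0", "pydantic": "2.11.0"}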
Args:
obj (Any):
Required. The object to scan for package requirements.
ignore_modules (Sequence[str]):
Optional. The set of modules to ignore. It defaults to the set of
built-in and stdlib modules.
package_distributions (Mapping[str, Sequence[str]]):
Optional. The mapping of module names to the set of packages that
contain them. It defaults to the set of packages from
`importlib_metadata.packages_distributions()`.
inspect_getmembers_kwargs (Mapping[str, Any]):
Optional. The keyword arguments to pass to `inspect.getmembers`. It
defaults to an empty dictionary.
Returns:
dict[str, str]: The packages (and their versions) that were discovered.
"""
if ignore_modules is None:
ignore_modules = _BASE_MODULES
if package_distributions is None:
package_distributions = _PACKAGE_DISTRIBUTIONS
modules_found = set(_DEFAULT_REQUIRED_PACKAGES)
inspect_getmembers_kwargs = inspect_getmembers_kwargs or {}
for _, attr in inspect.getmembers(obj, **inspect_getmembers_kwargs):
if not attr or inspect.isbuiltin(attr) or not hasattr(attr, "__module__"):
continue
module_name = (attr.__module__ or "").split(".")[0]
if module_name and module_name not in ignore_modules:
for module in package_distributions.get(module_name, []):
modules_found.add(module)
return {module: importlib_metadata.version(module) for module in modules_found}
def generate_schema(
f: Callable[..., Any],
*,
schema_name: Optional[str] = None,
descriptions: Mapping[str, str] = {},
required: Sequence[str] = [],
) -> JsonDict:
"""Generates the OpenAPI Schema for a callable object.
Only positional and keyword arguments of the function `f` will be supported
in the OpenAPI Schema that is generated. I.e. `*args` and `**kwargs` will
not be present in the OpenAPI schema returned from this function. For those
cases, you can either include it in the docstring for `f`, or modify the
OpenAPI schema returned from this function to include additional arguments.
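Example (a minimal sketch with a hypothetical function):
.. code-block:: python
    def get_exchange_rate(currency_from: str, currency_to: str = "USD"):
        ...  # body omitted
    schema = generate_schema(
        get_exchange_rate,
        descriptions={"currency_from": "The ISO 4217 code to convert from."},
    )
    # schema["name"] == "get_exchange_rate"
    # schema["parameters"]["required"] == ["currency_from"]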
Args:
f (Callable):
Required. The function to generate an OpenAPI Schema for.
schema_name (str):
Optional. The name for the OpenAPI schema. If unspecified, the name
of the Callable will be used.
descriptions (Mapping[str, str]):
Optional. A `{name: description}` mapping for annotating input
arguments of the function with user-provided descriptions. It
defaults to an empty dictionary (i.e. there will not be any
description for any of the inputs).
required (Sequence[str]):
Optional. For the user to specify the set of required arguments in
function calls to `f`. If unspecified, it will be automatically
inferred from `f`.
Returns:
dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format.
"""
pydantic = _import_pydantic_or_raise()
defaults = dict(inspect.signature(f).parameters)
fields_dict = {
name: (
# 1. We infer the argument type here: use Any rather than None so
# it will not try to auto-infer the type based on the default value.
(param.annotation if param.annotation != inspect.Parameter.empty else Any),
pydantic.Field(
# 2. We do not support default values for now.
# default=(
# param.default if param.default != inspect.Parameter.empty
# else None
# ),
# 3. We support user-provided descriptions.
description=descriptions.get(name, None),
),
)
for name, param in defaults.items()
# We do not support *args or **kwargs
if param.kind
in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
}
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
# * http://cl/586221780
parameters.pop("title", "")
for name, function_arg in parameters.get("properties", {}).items():
function_arg.pop("title", "")
annotation = defaults[name].annotation
# 5. Nullable fields:
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
if typing.get_origin(annotation) is Union and type(None) in typing.get_args(
annotation
):
# for "typing.Optional" arguments, function_arg might be a
# dictionary like
#
# {'anyOf': [{'type': 'integer'}, {'type': 'null'}]
for schema in function_arg.pop("anyOf", []):
schema_type = schema.get("type")
if schema_type and schema_type != "null":
function_arg["type"] = schema_type
break
function_arg["nullable"] = True
# 6. Annotate required fields.
if required:
# We use the user-provided "required" fields if specified.
parameters["required"] = required
else:
# Otherwise we infer it from the function signature.
parameters["required"] = [
k
for k in defaults
if (
defaults[k].default == inspect.Parameter.empty
and defaults[k].kind
in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
)
]
schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
if schema_name:
schema["name"] = schema_name
return schema
def is_noop_or_proxy_tracer_provider(tracer_provider) -> bool:
"""Returns True if the tracer_provider is Proxy or NoOp."""
opentelemetry = _import_opentelemetry_or_warn()
ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider
NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider
return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider))
def _import_cloud_storage_or_raise() -> types.ModuleType:
"""Tries to import the Cloud Storage module."""
try:
from google.cloud import storage
except ImportError as e:
raise ImportError(
"Cloud Storage is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return storage
def _import_cloudpickle_or_raise() -> types.ModuleType:
"""Tries to import the cloudpickle module."""
try:
import cloudpickle # noqa:F401
except ImportError as e:
raise ImportError(
"cloudpickle is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return cloudpickle
def _import_pydantic_or_raise() -> types.ModuleType:
"""Tries to import the pydantic module."""
try:
import pydantic
_ = pydantic.Field
except AttributeError:
from pydantic import v1 as pydantic
except ImportError as e:
raise ImportError(
"pydantic is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return pydantic
def _import_packaging_requirements_or_raise() -> types.ModuleType:
"""Tries to import the packaging.requirements module."""
try:
from packaging import requirements
except ImportError as e:
raise ImportError(
"packaging.requirements is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return requirements
def _import_packaging_version_or_raise() -> types.ModuleType:
"""Tries to import the packaging.requirements module."""
try:
from packaging import version
except ImportError as e:
raise ImportError(
"packaging.version is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
) from e
return version
def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry module."""
try:
import opentelemetry # noqa:F401
return opentelemetry
except ImportError:
LOGGER.warning(
"opentelemetry-sdk is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
)
return None
def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry.sdk.trace module."""
try:
import opentelemetry.sdk.trace # noqa:F401
return opentelemetry.sdk.trace
except ImportError:
LOGGER.warning(
"opentelemetry-sdk is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
)
return None
def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the google.cloud.trace_v2 module."""
try:
import google.cloud.trace_v2
return google.cloud.trace_v2
except ImportError:
LOGGER.warning(
"google-cloud-trace is not installed. Please call "
"'pip install google-cloud-aiplatform[agent_engines]'."
)
return None
def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the opentelemetry.exporter.cloud_trace module."""
try:
import opentelemetry.exporter.cloud_trace # noqa:F401
return opentelemetry.exporter.cloud_trace
except ImportError:
LOGGER.warning(
"opentelemetry-exporter-gcp-trace is not installed. Please "
"call 'pip install google-cloud-aiplatform[agent_engines]'."
)
return None
def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the openinference.instrumentation.langchain module."""
try:
import openinference.instrumentation.langchain # noqa:F401
return openinference.instrumentation.langchain
except ImportError:
LOGGER.warning(
"openinference-instrumentation-langchain is not installed. Please "
"call 'pip install google-cloud-aiplatform[langchain]'."
)
return None
def _import_openinference_autogen_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the openinference.instrumentation.autogen module."""
try:
import openinference.instrumentation.autogen # noqa:F401
return openinference.instrumentation.autogen
except ImportError:
LOGGER.warning(
"openinference-instrumentation-autogen is not installed. Please "
"call 'pip install google-cloud-aiplatform[ag2]'."
)
return None
def _import_autogen_tools_or_warn() -> Optional[types.ModuleType]:
"""Tries to import the autogen.tools module."""
try:
from autogen import tools
return tools
except ImportError:
LOGGER.warning(
"autogen.tools is not installed. Please "
"call `pip install google-cloud-aiplatform[ag2]`."
)
return None


@@ -0,0 +1,487 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
try:
from autogen import agentchat
ConversableAgent = agentchat.ConversableAgent
ChatResult = agentchat.ChatResult
except ImportError:
ConversableAgent = Any
ChatResult = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
def _prepare_runnable_kwargs(
runnable_kwargs: Mapping[str, Any],
system_instruction: str,
runnable_name: str,
llm_config: Mapping[str, Any],
) -> Mapping[str, Any]:
"""Prepares the configuration for a runnable, applying defaults and enforcing constraints."""
if runnable_kwargs is None:
runnable_kwargs = {}
if (
"human_input_mode" in runnable_kwargs
and runnable_kwargs["human_input_mode"] != "NEVER"
):
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
f"human_input_mode={runnable_kwargs['human_input_mode']}"
"is not supported. Will be enforced to 'NEVER'."
)
runnable_kwargs["human_input_mode"] = "NEVER"
if "system_message" not in runnable_kwargs and system_instruction:
runnable_kwargs["system_message"] = system_instruction
if "name" not in runnable_kwargs:
runnable_kwargs["name"] = runnable_name
if "llm_config" not in runnable_kwargs:
runnable_kwargs["llm_config"] = llm_config
return runnable_kwargs
def _default_runnable_builder(
**runnable_kwargs: Any,
) -> "ConversableAgent":
from autogen import agentchat
return agentchat.ConversableAgent(**runnable_kwargs)
def _default_instrumentor_builder(project_id: str):
from vertexai.agent_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_autogen = _utils._import_openinference_autogen_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_autogen,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=project_id,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(project_id),
),
)
span_processor: SpanProcessor = (
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
)
tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
# Keep the instrumentation up-to-date.
# When creating multiple AG2Agents,
# we need to keep the instrumentation up-to-date.
# We deliberately override the instrument each time,
# so that if different agents end up using different
# instrumentations, we guarantee that the user is always
# working with the most recent agent's instrumentation.
instrumentor = openinference_autogen.AutogenInstrumentor()
instrumentor.uninstrument()
instrumentor.instrument()
return instrumentor
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
return None
def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing AG2 tools that are
usable with Gemini function calling.
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)
def _validate_tools(tools: Sequence[Callable[..., Any]]):
"""Validates that the tools are usable for tool calling."""
for tool in tools:
if isinstance(tool, Callable):
_validate_callable_parameters_are_annotated(tool)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple AG2Agents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
While we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new AG2Agent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
active_span_processor (SynchronousMultiSpanProcessor):
The span processor that will replace the tracer provider's
currently active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class AG2Agent:
"""An AG2 Agent."""
def __init__(
self,
model: str,
runnable_name: str,
*,
api_type: Optional[str] = None,
llm_config: Optional[Mapping[str, Any]] = None,
system_instruction: Optional[str] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None,
tools: Optional[Sequence[Callable[..., Any]]] = None,
enable_tracing: bool = False,
instrumentor_builder: Optional[Callable[..., Any]] = None,
):
"""Initializes the AG2 Agent.
Under-the-hood, assuming .set_up() is called, this will correspond to
```python
# runnable_builder
runnable = runnable_builder(
llm_config=llm_config,
system_message=system_instruction,
**runnable_kwargs,
)
```
When everything is based on their default values, this corresponds to
```python
# llm_config
llm_config = {
"config_list": [{
"project_id": initializer.global_config.project,
"location": initializer.global_config.location,
"model": "gemini-1.0-pro-001",
"api_type": "google",
}]
}
# runnable_builder
runnable = ConversableAgent(
llm_config=llm_config,
name="Default AG2 Agent"
system_message="You are a helpful AI Assistant.",
human_input_mode="NEVER",
)
```
By default, if `llm_config` is not specified, a default configuration
will be created using the provided `model` and `api_type`.
If `runnable_builder` is not specified, a default runnable builder will
be used, configured with the `system_instruction`, `runnable_name` and
`runnable_kwargs`.
Args:
model (str):
Required. The name of the model (e.g. "gemini-1.0-pro").
Used to create a default `llm_config` if one is not provided.
This parameter is ignored if `llm_config` is provided.
runnable_name (str):
Required. The name of the runnable.
This name is used as the default `runnable_kwargs["name"]`
unless `runnable_kwargs` already contains a "name", in which
case the provided `runnable_kwargs["name"]` will be used.
api_type (str):
Optional. The API type to use for the language model.
Used to create a default `llm_config` if one is not provided.
This parameter is ignored if `llm_config` is provided.
llm_config (Mapping[str, Any]):
Optional. Configuration dictionary for the language model.
If provided, this configuration will be used directly.
Otherwise, a default `llm_config` will be created using `model`
and `api_type`. This `llm_config` is used as the default
`runnable_kwargs["llm_config"]` unless `runnable_kwargs` already
contains a "llm_config", in which case the provided
`runnable_kwargs["llm_config"]` will be used.
system_instruction (str):
Optional. The system instruction for the agent.
This instruction is used as the default
`runnable_kwargs["system_message"]` unless `runnable_kwargs`
already contains a "system_message", in which case the provided
`runnable_kwargs["system_message"]` will be used.
runnable_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
the runnable. Details of the kwargs can be found in
https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent.
`runnable_kwargs` only supports `human_input_mode="NEVER"`.
Other `human_input_mode` values will trigger a warning.
runnable_builder (Callable[..., "ConversableAgent"]):
Optional. Callable that returns a new runnable. This can be used
for customizing the orchestration logic of the Agent.
If not provided, a default runnable builder will be used.
tools (Sequence[Callable[..., Any]]):
Optional. The tools for the agent to be able to use. All input
callables (e.g. function or class method) will be converted
to an AG2 tool. Defaults to None.
enable_tracing (bool):
Optional. Whether to enable tracing in Cloud Trace. Defaults to
False.
instrumentor_builder (Callable[..., Any]):
Optional. Callable that returns a new instrumentor. This can be
used for customizing the instrumentation logic of the Agent.
If not provided, a default instrumentor builder will be used.
This parameter is ignored if `enable_tracing` is False.
"""
from google.cloud.aiplatform import initializer
self._tmpl_attrs: dict[str, Any] = {
"project": initializer.global_config.project,
"location": initializer.global_config.location,
"model_name": model,
"api_type": api_type or "google",
"system_instruction": system_instruction,
"runnable_name": runnable_name,
"tools": [],
"ag2_tool_objects": [],
"runnable": None,
"runnable_builder": runnable_builder,
"instrumentor": None,
"instrumentor_builder": instrumentor_builder,
"enable_tracing": enable_tracing,
}
self._tmpl_attrs["llm_config"] = llm_config or {
"config_list": [
{
"project_id": self._tmpl_attrs.get("project"),
"location": self._tmpl_attrs.get("location"),
"model": self._tmpl_attrs.get("model_name"),
"api_type": self._tmpl_attrs.get("api_type"),
}
]
}
self._tmpl_attrs["runnable_kwargs"] = _prepare_runnable_kwargs(
runnable_kwargs=runnable_kwargs,
llm_config=self._tmpl_attrs.get("llm_config"),
system_instruction=self._tmpl_attrs.get("system_instruction"),
runnable_name=self._tmpl_attrs.get("runnable_name"),
)
if tools:
# We validate tools at initialization for actionable feedback before
# they are deployed.
_validate_tools(tools)
self._tmpl_attrs["tools"] = tools
def set_up(self):
"""Sets up the agent for execution of queries at runtime.
It initializes the runnable and binds it with the tools.
This method should not be called for an object that is being passed to
the ReasoningEngine service for deployment, as it initializes clients
that cannot be serialized.
"""
if self._tmpl_attrs.get("enable_tracing"):
instrumentor_builder = (
self._tmpl_attrs.get("instrumentor_builder")
or _default_instrumentor_builder
)
self._tmpl_attrs["instrumentor"] = instrumentor_builder(
project_id=self._tmpl_attrs.get("project")
)
# Set up tools.
tools = self._tmpl_attrs.get("tools")
ag2_tool_objects = self._tmpl_attrs.get("ag2_tool_objects")
if tools and not ag2_tool_objects:
from vertexai.agent_engines import _utils
autogen_tools = _utils._import_autogen_tools_or_warn()
if autogen_tools:
for tool in tools:
ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))
# Set up runnable.
runnable_builder = (
self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
)
self._tmpl_attrs["runnable"] = runnable_builder(
**self._tmpl_attrs.get("runnable_kwargs")
)
def clone(self) -> "AG2Agent":
"""Returns a clone of the AG2Agent."""
import copy
return AG2Agent(
model=self._tmpl_attrs.get("model_name"),
api_type=self._tmpl_attrs.get("api_type"),
llm_config=copy.deepcopy(self._tmpl_attrs.get("llm_config")),
system_instruction=self._tmpl_attrs.get("system_instruction"),
runnable_name=self._tmpl_attrs.get("runnable_name"),
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
)
def query(
self,
*,
input: Union[str, Mapping[str, Any]],
max_turns: Optional[int] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Queries the Agent with the given input.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
max_turns (int):
Optional. The maximum number of turns to run the agent for.
If not provided, the agent will run indefinitely.
If `max_turns` is a `float`, it will be converted to `int`
through rounding.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.run()` method of the corresponding runnable.
Details of the kwargs can be found in
https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent#run.
The `user_input` parameter defaults to `False`, and should not
be passed through `kwargs`.
Returns:
The output of querying the Agent with the given input.
"""
if isinstance(input, str):
input = {"content": input}
if max_turns and isinstance(max_turns, float):
# Supporting auto-conversion float to int.
max_turns = round(max_turns)
if "user_input" in kwargs:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"The `user_input` parameter should not be passed through"
"kwargs. The `user_input` defaults to `False`."
)
kwargs.pop("user_input")
if not self._tmpl_attrs.get("runnable"):
self.set_up()
response = self._tmpl_attrs.get("runnable").run(
message=input,
user_input=False,
tools=self._tmpl_attrs.get("ag2_tool_objects"),
max_turns=max_turns,
**kwargs,
)
from vertexai.agent_engines import _utils
return _utils.to_json_serializable_autogen_object(response)


@@ -0,0 +1,673 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
try:
from langchain_core import runnables
from langchain_core import tools as lc_tools
from langchain_core.language_models import base as lc_language_models
BaseTool = lc_tools.BaseTool
BaseLanguageModel = lc_language_models.BaseLanguageModel
GetSessionHistoryCallable = runnables.history.GetSessionHistoryCallable
RunnableConfig = runnables.RunnableConfig
RunnableSerializable = runnables.RunnableSerializable
except ImportError:
BaseTool = Any
BaseLanguageModel = Any
GetSessionHistoryCallable = Any
RunnableConfig = Any
RunnableSerializable = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType
_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
# https://github.com/langchain-ai/langchain/blob/5784dfed001730530637793bea1795d9d5a7c244/libs/core/langchain_core/runnables/history.py#L237-L241
runnable_kwargs = {
# input_messages_key (str): Must be specified if the underlying
# agent accepts a dict as input.
"input_messages_key": "input",
# output_messages_key (str): Must be specified if the underlying
# agent returns a dict as output.
"output_messages_key": "output",
}
if has_history:
# history_messages_key (str): Must be specified if the underlying
# agent accepts a dict as input and a separate key for historical
# messages.
runnable_kwargs["history_messages_key"] = "history"
return runnable_kwargs
def _default_output_parser():
try:
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.output_parsers.openai_tools import (
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
)
return ToolsAgentOutputParser()
def _default_model_builder(
model_name: str,
*,
project: str,
location: str,
model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI
model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
def _default_runnable_builder(
model: "BaseLanguageModel",
*,
system_instruction: Optional[str] = None,
tools: Optional[Sequence["_ToolLike"]] = None,
prompt: Optional["RunnableSerializable"] = None,
output_parser: Optional["RunnableSerializable"] = None,
chat_history: Optional["GetSessionHistoryCallable"] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
from langchain_core import tools as lc_tools
from langchain.agents import AgentExecutor
from langchain.tools.base import StructuredTool
# The prompt template and runnable_kwargs need to be customized depending
# on whether the user intends for the agent to have history. The way the
# user would reflect that is by setting chat_history (which defaults to
# None).
has_history: bool = chat_history is not None
prompt = prompt or _default_prompt(
has_history=has_history,
system_instruction=system_instruction,
)
output_parser = output_parser or _default_output_parser()
model_tool_kwargs = model_tool_kwargs or {}
agent_executor_kwargs = agent_executor_kwargs or {}
runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history)
if tools:
model = model.bind_tools(tools=tools, **model_tool_kwargs)
else:
tools = []
agent_executor = AgentExecutor(
agent=prompt | model | output_parser,
tools=[
tool
if isinstance(tool, lc_tools.BaseTool)
else StructuredTool.from_function(tool)
for tool in tools
if isinstance(tool, (Callable, lc_tools.BaseTool))
],
**agent_executor_kwargs,
)
if has_history:
from langchain_core.runnables.history import RunnableWithMessageHistory
return RunnableWithMessageHistory(
runnable=agent_executor,
get_session_history=chat_history,
**runnable_kwargs,
)
return agent_executor
def _default_instrumentor_builder(project_id: str):
from vertexai.agent_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=project_id,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(project_id),
),
)
span_processor: SpanProcessor = (
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
)
tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
# Keep the instrumentation up-to-date.
# When creating multiple LangchainAgents,
# we need to keep the instrumentation up-to-date.
# We deliberately override the instrument each time,
# so that if different agents end up using different
# instrumentations, we guarantee that the user is always
# working with the most recent agent's instrumentation.
instrumentor = openinference_langchain.LangChainInstrumentor()
if instrumentor.is_instrumented_by_opentelemetry:
instrumentor.uninstrument()
instrumentor.instrument()
return instrumentor
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
return None
def _default_prompt(
has_history: bool,
system_instruction: Optional[str] = None,
) -> "RunnableSerializable":
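    """Default callable for building the prompt of the agent.

    Args:
        has_history (bool):
            Required. Whether the prompt expects a "history" of prior messages.
        system_instruction (str):
            Optional. The system instruction to prepend to the prompt messages.

    Returns:
        RunnableSerializable: The prompt template pipeline (input mapping
        piped into a ChatPromptTemplate).
    """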
from langchain_core import prompts
try:
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages as format_to_tool_messages,
)
system_instructions = []
if system_instruction:
system_instructions = [("system", system_instruction)]
if has_history:
return {
"history": lambda x: x["history"],
"input": lambda x: x["input"],
"agent_scratchpad": (
lambda x: format_to_tool_messages(x["intermediate_steps"])
),
} | prompts.ChatPromptTemplate.from_messages(
system_instructions
+ [
prompts.MessagesPlaceholder(variable_name="history"),
("user", "{input}"),
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
else:
return {
"input": lambda x: x["input"],
"agent_scratchpad": (
lambda x: format_to_tool_messages(x["intermediate_steps"])
),
} | prompts.ChatPromptTemplate.from_messages(
system_instructions
+ [
("user", "{input}"),
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing LangChain tools that are
usable with Gemini function calling.
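
    Args:
        callable (Callable): The callable to validate.

    Raises:
        TypeError: If any parameter is not annotated.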
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)
def _validate_tools(tools: Sequence["_ToolLike"]):
"""Validates that the tools are usable for tool calling."""
for tool in tools:
if isinstance(tool, Callable):
_validate_callable_parameters_are_annotated(tool)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple LangchainAgents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
While we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new LangchainAgent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that replaces the tracer provider's
            currently active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class LangchainAgent:
"""A Langchain Agent.
See https://cloud.google.com/vertex-ai/generative-ai/docs/reasoning-engine/develop
for details.
"""
def __init__(
self,
model: str,
*,
system_instruction: Optional[str] = None,
prompt: Optional["RunnableSerializable"] = None,
tools: Optional[Sequence["_ToolLike"]] = None,
output_parser: Optional["RunnableSerializable"] = None,
chat_history: Optional["GetSessionHistoryCallable"] = None,
model_kwargs: Optional[Mapping[str, Any]] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
model_builder: Optional[Callable] = None,
runnable_builder: Optional[Callable] = None,
enable_tracing: bool = False,
instrumentor_builder: Optional[Callable[..., Any]] = None,
):
"""Initializes the LangchainAgent.
        Under the hood, assuming `.set_up()` is called, this will correspond to
```
model = model_builder(model_name=model, model_kwargs=model_kwargs)
runnable = runnable_builder(
prompt=prompt,
model=model,
tools=tools,
output_parser=output_parser,
chat_history=chat_history,
agent_executor_kwargs=agent_executor_kwargs,
runnable_kwargs=runnable_kwargs,
)
```
When everything is based on their default values, this corresponds to
```
# model_builder
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name=model, **model_kwargs)
# runnable_builder
from langchain import agents
from langchain_core.runnables.history import RunnableWithMessageHistory
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
agent_executor = agents.AgentExecutor(
agent=prompt | llm_with_tools | output_parser,
tools=tools,
**agent_executor_kwargs,
)
runnable = RunnableWithMessageHistory(
runnable=agent_executor,
get_session_history=chat_history,
**runnable_kwargs,
)
```
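
        A minimal usage sketch (illustrative only; the `get_exchange_rate`
        function below is a hypothetical tool, and any callable with typed
        parameters would work):

        ```
        def get_exchange_rate(currency_from: str, currency_to: str) -> dict:
            # Hypothetical tool implementation; replace with a real lookup.
            return {"from": currency_from, "to": currency_to, "rate": 0.85}

        agent = LangchainAgent(
            model="gemini-1.0-pro",
            tools=[get_exchange_rate],
            model_kwargs={"temperature": 0.28},
        )
        agent.set_up()
        response = agent.query(input="What is the exchange rate from USD to EUR?")
        ```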
Args:
model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
system_instruction (str):
Optional. The system instruction to use for the agent. This
argument should not be specified if `prompt` is specified.
prompt (langchain_core.runnables.RunnableSerializable):
Optional. The prompt template for the model. Defaults to a
ChatPromptTemplate.
tools (Sequence[langchain_core.tools.BaseTool, Callable]):
Optional. The tools for the agent to be able to use. All input
callables (e.g. function or class method) will be converted
to a langchain.tools.base.StructuredTool. Defaults to None.
output_parser (langchain_core.runnables.RunnableSerializable):
Optional. The output parser for the model. Defaults to an
output parser that works with Gemini function-calling.
chat_history (langchain_core.runnables.history.GetSessionHistoryCallable):
Optional. Callable that returns a new BaseChatMessageHistory.
Defaults to None, i.e. chat_history is not preserved.
model_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
chat_models.ChatVertexAI. An example would be
```
{
# temperature (float): Sampling temperature, it controls the
# degree of randomness in token selection.
"temperature": 0.28,
# max_output_tokens (int): Token limit determines the
# maximum amount of text output from one prompt.
"max_output_tokens": 1000,
# top_p (float): Tokens are selected from most probable to
# least, until the sum of their probabilities equals the
# top_p value.
"top_p": 0.95,
# top_k (int): How the model selects tokens for output, the
# next token is selected from among the top_k most probable
# tokens.
"top_k": 40,
}
```
model_tool_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments when binding tools to the
model using `model.bind_tools()`.
agent_executor_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
langchain.agents.AgentExecutor. An example would be
```
{
# Whether to return the agent's trajectory of intermediate
# steps at the end in addition to the final output.
"return_intermediate_steps": False,
# The maximum number of steps to take before ending the
# execution loop.
"max_iterations": 15,
# The method to use for early stopping if the agent never
# returns `AgentFinish`. Either 'force' or 'generate'.
"early_stopping_method": "force",
# How to handle errors raised by the agent's output parser.
# Defaults to `False`, which raises the error.
"handle_parsing_errors": False,
}
```
runnable_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
langchain.runnables.history.RunnableWithMessageHistory if
chat_history is specified. If chat_history is None, this will be
ignored.
model_builder (Callable):
Optional. Callable that returns a new language model. Defaults
                to a callable that returns ChatVertexAI based on `model`,
`model_kwargs` and the parameters in `vertexai.init`.
runnable_builder (Callable):
Optional. Callable that returns a new runnable. This can be used
for customizing the orchestration logic of the Agent based on
the model returned by `model_builder` and the rest of the input
arguments.
enable_tracing (bool):
Optional. Whether to enable tracing in Cloud Trace. Defaults to
False.
instrumentor_builder (Callable[..., Any]):
Optional. Callable that returns a new instrumentor. This can be
used for customizing the instrumentation logic of the Agent.
If not provided, a default instrumentor builder will be used.
This parameter is ignored if `enable_tracing` is False.
Raises:
ValueError: If both `prompt` and `system_instruction` are specified.
TypeError: If there is an invalid tool (e.g. function with an input
that did not specify its type).
"""
from google.cloud.aiplatform import initializer
self._tmpl_attrs: dict[str, Any] = {
"project": initializer.global_config.project,
"location": initializer.global_config.location,
"tools": [],
"model_name": model,
"system_instruction": system_instruction,
"prompt": prompt,
"output_parser": output_parser,
"chat_history": chat_history,
"model_kwargs": model_kwargs,
"model_tool_kwargs": model_tool_kwargs,
"agent_executor_kwargs": agent_executor_kwargs,
"runnable_kwargs": runnable_kwargs,
"model_builder": model_builder,
"runnable_builder": runnable_builder,
"enable_tracing": enable_tracing,
"model": None,
"runnable": None,
"instrumentor": None,
"instrumentor_builder": instrumentor_builder,
}
if tools:
# We validate tools at initialization for actionable feedback before
# they are deployed.
_validate_tools(tools)
self._tmpl_attrs["tools"] = tools
if prompt and system_instruction:
raise ValueError(
"Only one of `prompt` or `system_instruction` should be specified. "
"Consider incorporating the system instruction into the prompt "
"rather than passing it separately as an argument."
)
def set_up(self):
"""Sets up the agent for execution of queries at runtime.
It initializes the model, binds the model with tools, and connects it
with the prompt template and output parser.
This method should not be called for an object being passed to the
        service for deployment, as it might initialize clients that cannot be
        serialized.
"""
if self._tmpl_attrs.get("enable_tracing"):
instrumentor_builder = (
self._tmpl_attrs.get("instrumentor_builder")
or _default_instrumentor_builder
)
self._tmpl_attrs["instrumentor"] = instrumentor_builder(
project_id=self._tmpl_attrs.get("project")
)
model_builder = self._tmpl_attrs.get("model_builder") or _default_model_builder
self._tmpl_attrs["model"] = model_builder(
model_name=self._tmpl_attrs.get("model_name"),
model_kwargs=self._tmpl_attrs.get("model_kwargs"),
project=self._tmpl_attrs.get("project"),
location=self._tmpl_attrs.get("location"),
)
runnable_builder = (
self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
)
self._tmpl_attrs["runnable"] = runnable_builder(
prompt=self._tmpl_attrs.get("prompt"),
model=self._tmpl_attrs.get("model"),
tools=self._tmpl_attrs.get("tools"),
system_instruction=self._tmpl_attrs.get("system_instruction"),
output_parser=self._tmpl_attrs.get("output_parser"),
chat_history=self._tmpl_attrs.get("chat_history"),
model_tool_kwargs=self._tmpl_attrs.get("model_tool_kwargs"),
agent_executor_kwargs=self._tmpl_attrs.get("agent_executor_kwargs"),
runnable_kwargs=self._tmpl_attrs.get("runnable_kwargs"),
)
def clone(self) -> "LangchainAgent":
"""Returns a clone of the LangchainAgent."""
import copy
return LangchainAgent(
model=self._tmpl_attrs.get("model_name"),
system_instruction=self._tmpl_attrs.get("system_instruction"),
prompt=copy.deepcopy(self._tmpl_attrs.get("prompt")),
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
output_parser=copy.deepcopy(self._tmpl_attrs.get("output_parser")),
chat_history=copy.deepcopy(self._tmpl_attrs.get("chat_history")),
model_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_kwargs")),
model_tool_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_tool_kwargs")),
agent_executor_kwargs=copy.deepcopy(
self._tmpl_attrs.get("agent_executor_kwargs")
),
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
model_builder=self._tmpl_attrs.get("model_builder"),
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
)
def query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.invoke()` method of the corresponding AgentExecutor.
Returns:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._tmpl_attrs.get("runnable"):
self.set_up()
return langchain_load_dump.dumpd(
self._tmpl_attrs.get("runnable").invoke(
input=input, config=config, **kwargs
)
)
def stream_query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs,
) -> Iterable[Any]:
"""Stream queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
                `.stream()` method of the corresponding AgentExecutor.
Yields:
The output of querying the Agent with the given input and config.
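
        A minimal sketch (assuming `agent` is a LangchainAgent that has
        already been constructed):

        ```
        for chunk in agent.stream_query(input="What is the exchange rate?"):
            print(chunk)
        ```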
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._tmpl_attrs.get("runnable"):
self.set_up()
for chunk in self._tmpl_attrs.get("runnable").stream(
input=input,
config=config,
**kwargs,
):
yield langchain_load_dump.dumpd(chunk)

View File

@@ -0,0 +1,692 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
try:
from langchain_core import runnables
from langchain_core import tools as lc_tools
from langchain_core.language_models import base as lc_language_models
BaseTool = lc_tools.BaseTool
BaseLanguageModel = lc_language_models.BaseLanguageModel
RunnableConfig = runnables.RunnableConfig
RunnableSerializable = runnables.RunnableSerializable
except ImportError:
BaseTool = Any
BaseLanguageModel = Any
RunnableConfig = Any
RunnableSerializable = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType
_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
try:
from langgraph_checkpoint.checkpoint import base
BaseCheckpointSaver = base.BaseCheckpointSaver
except ImportError:
try:
from langgraph.checkpoint import base
BaseCheckpointSaver = base.BaseCheckpointSaver
except ImportError:
BaseCheckpointSaver = Any
def _default_model_builder(
model_name: str,
*,
project: str,
location: str,
model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
"""Default callable for building a language model.
Args:
model_name (str):
Required. The name of the model (e.g. "gemini-1.0-pro").
project (str):
Required. The Google Cloud project ID.
location (str):
Required. The Google Cloud location.
model_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
chat_models.ChatVertexAI.
Returns:
BaseLanguageModel: The language model.
"""
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI
model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
def _default_runnable_builder(
model: "BaseLanguageModel",
*,
tools: Optional[Sequence["_ToolLike"]] = None,
checkpointer: Optional[Any] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
"""Default callable for building a runnable.
Args:
model (BaseLanguageModel):
Required. The language model.
tools (Optional[Sequence[_ToolLike]]):
Optional. The tools for the agent to be able to use.
        checkpointer (Optional[BaseCheckpointSaver]):
            Optional. The checkpointer for the agent.
model_tool_kwargs (Optional[Mapping[str, Any]]):
Optional. Additional keyword arguments when binding tools to the model.
runnable_kwargs (Optional[Mapping[str, Any]]):
Optional. Additional keyword arguments for the runnable.
Returns:
RunnableSerializable: The runnable.
"""
from langgraph import prebuilt as langgraph_prebuilt
model_tool_kwargs = model_tool_kwargs or {}
runnable_kwargs = runnable_kwargs or {}
if tools:
model = model.bind_tools(tools=tools, **model_tool_kwargs)
else:
tools = []
if checkpointer:
if "checkpointer" in runnable_kwargs:
from google.cloud.aiplatform import base
base.Logger(__name__).warning(
"checkpointer is being specified in both checkpointer_builder "
"and runnable_kwargs. Please specify it in only one of them. "
"Overriding the checkpointer in runnable_kwargs."
)
runnable_kwargs["checkpointer"] = checkpointer
return langgraph_prebuilt.create_react_agent(
model,
tools=tools,
**runnable_kwargs,
)
def _default_instrumentor_builder(project_id: str):
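    """Default callable for building the tracing instrumentor.

    Exports LangChain/LangGraph spans to Cloud Trace for the given project.
    If any of the packages required for tracing is missing, a warning is
    logged and None is returned.

    Args:
        project_id (str): The Google Cloud project ID to export traces to.

    Returns:
        The LangChain instrumentor, or None if tracing packages are missing.
    """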
from vertexai.agent_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=project_id,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(project_id),
),
)
span_processor: SpanProcessor = (
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
)
tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
base.Logger(__name__).warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids AttributeError:
        # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
        # attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
        # Keep the instrumentation up-to-date.
        # When creating multiple LanggraphAgents, we deliberately re-instrument
        # each time, so that if different agents end up using different
        # instrumentations, the user is always working with the most recent
        # agent's instrumentation.
instrumentor = openinference_langchain.LangChainInstrumentor()
if instrumentor.is_instrumented_by_opentelemetry:
instrumentor.uninstrument()
instrumentor.instrument()
return instrumentor
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
return None
def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing LangChain tools that are
usable with Gemini function calling.
Args:
callable (Callable): The callable to validate.
Raises:
TypeError: If any parameter is not annotated.
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)
def _validate_tools(tools: Sequence["_ToolLike"]):
"""Validates that the tools are usable for tool calling.
Args:
tools (Sequence[_ToolLike]): The tools to validate.
Raises:
TypeError: If any tool is a callable with untyped parameters.
"""
for tool in tools:
if isinstance(tool, Callable):
_validate_callable_parameters_are_annotated(tool)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
    When working with multiple LanggraphAgents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
While we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
    the active span processor whenever a new LanggraphAgent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that replaces the tracer provider's
            currently active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class LanggraphAgent:
"""A LangGraph Agent."""
def __init__(
self,
model: str,
*,
tools: Optional[Sequence["_ToolLike"]] = None,
model_kwargs: Optional[Mapping[str, Any]] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
model_builder: Optional[Callable[..., "BaseLanguageModel"]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
runnable_builder: Optional[Callable[..., "RunnableSerializable"]] = None,
checkpointer_kwargs: Optional[Mapping[str, Any]] = None,
checkpointer_builder: Optional[Callable[..., "BaseCheckpointSaver"]] = None,
enable_tracing: bool = False,
instrumentor_builder: Optional[Callable[..., Any]] = None,
):
"""Initializes the LangGraph Agent.
        Under the hood, assuming `.set_up()` is called, this will correspond to
```python
model = model_builder(model_name=model, model_kwargs=model_kwargs)
runnable = runnable_builder(
model=model,
tools=tools,
model_tool_kwargs=model_tool_kwargs,
runnable_kwargs=runnable_kwargs,
)
```
When everything is based on their default values, this corresponds to
```python
# model_builder
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name=model, **model_kwargs)
# runnable_builder
from langgraph.prebuilt import create_react_agent
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
runnable = create_react_agent(
llm_with_tools,
tools=tools,
**runnable_kwargs,
)
```
By default, no checkpointer is used (i.e. there is no state history). To
enable checkpointing, provide a `checkpointer_builder` function that
returns a checkpointer instance.
**Example using Spanner:**
```python
def checkpointer_builder(instance_id, database_id, project_id, **kwargs):
from langchain_google_spanner import SpannerCheckpointSaver
checkpointer = SpannerCheckpointSaver(instance_id, database_id, project_id)
with checkpointer.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS checkpoints")
cur.execute("DROP TABLE IF EXISTS checkpoint_writes")
checkpointer.setup()
return checkpointer
```
**Example using an in-memory checkpointer:**
```python
def checkpointer_builder(**kwargs):
from langgraph.checkpoint.memory import MemorySaver
return MemorySaver()
```
        The `checkpointer_builder` function will be called with the keyword
        arguments provided via `checkpointer_kwargs`. Ensure your
        `checkpointer_builder` function accepts `**kwargs` so it can ignore
        any arguments it does not use.
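
        A minimal usage sketch tying these pieces together (illustrative only;
        `search` is a hypothetical tool and the thread ID is arbitrary):

        ```python
        def search(query: str) -> str:
            # Hypothetical tool with typed parameters.
            return f"Results for: {query}"

        def checkpointer_builder(**kwargs):
            from langgraph.checkpoint.memory import MemorySaver
            return MemorySaver()

        agent = LanggraphAgent(
            model="gemini-1.0-pro",
            tools=[search],
            checkpointer_builder=checkpointer_builder,
        )
        agent.set_up()
        response = agent.query(
            input={"messages": [("user", "Look something up for me")]},
            config={"configurable": {"thread_id": "demo-thread"}},
        )
        ```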
Args:
model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
tools (Sequence[langchain_core.tools.BaseTool, Callable]):
Optional. The tools for the agent to be able to use. All input
callables (e.g. function or class method) will be converted
to a langchain.tools.base.StructuredTool. Defaults to None.
model_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
chat_models.ChatVertexAI. An example would be
```
{
# temperature (float): Sampling temperature, it controls the
# degree of randomness in token selection.
"temperature": 0.28,
# max_output_tokens (int): Token limit determines the
# maximum amount of text output from one prompt.
"max_output_tokens": 1000,
# top_p (float): Tokens are selected from most probable to
# least, until the sum of their probabilities equals the
# top_p value.
"top_p": 0.95,
# top_k (int): How the model selects tokens for output, the
# next token is selected from among the top_k most probable
# tokens.
"top_k": 40,
}
```
model_tool_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments when binding tools to the
model using `model.bind_tools()`.
model_builder (Callable[..., "BaseLanguageModel"]):
Optional. Callable that returns a new language model. Defaults
                to a callable that returns ChatVertexAI based on `model`,
`model_kwargs` and the parameters in `vertexai.init`.
runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the runnable
                returned by `runnable_builder` (by default,
                `langgraph.prebuilt.create_react_agent`).
runnable_builder (Callable[..., "RunnableSerializable"]):
Optional. Callable that returns a new runnable. This can be used
for customizing the orchestration logic of the Agent based on
the model returned by `model_builder` and the rest of the input
arguments.
checkpointer_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
the checkpointer returned by `checkpointer_builder`.
checkpointer_builder (Callable[..., "BaseCheckpointSaver"]):
Optional. Callable that returns a checkpointer. This can be used
for defining the checkpointer of the Agent. Defaults to None.
enable_tracing (bool):
Optional. Whether to enable tracing in Cloud Trace. Defaults to
False.
instrumentor_builder (Callable[..., Any]):
Optional. Callable that returns a new instrumentor. This can be
used for customizing the instrumentation logic of the Agent.
If not provided, a default instrumentor builder will be used.
This parameter is ignored if `enable_tracing` is False.
Raises:
TypeError: If there is an invalid tool (e.g. function with an input
that did not specify its type).
"""
from google.cloud.aiplatform import initializer
self._tmpl_attrs: dict[str, Any] = {
"project": initializer.global_config.project,
"location": initializer.global_config.location,
"tools": [],
"model_name": model,
"model_kwargs": model_kwargs,
"model_tool_kwargs": model_tool_kwargs,
"runnable_kwargs": runnable_kwargs,
"checkpointer_kwargs": checkpointer_kwargs,
"model": None,
"model_builder": model_builder,
"runnable": None,
"runnable_builder": runnable_builder,
"checkpointer": None,
"checkpointer_builder": checkpointer_builder,
"enable_tracing": enable_tracing,
"instrumentor": None,
"instrumentor_builder": instrumentor_builder,
}
if tools:
# We validate tools at initialization for actionable feedback before
# they are deployed.
_validate_tools(tools)
self._tmpl_attrs["tools"] = tools
def set_up(self):
"""Sets up the agent for execution of queries at runtime.
        It initializes the model, binds the model with tools, and builds the
        LangGraph runnable (along with the checkpointer, if one is configured).
        This method should not be called for an object that is being passed to
        the ReasoningEngine service for deployment, as it initializes clients
        that cannot be serialized.
"""
if self._tmpl_attrs.get("enable_tracing"):
instrumentor_builder = (
self._tmpl_attrs.get("instrumentor_builder")
or _default_instrumentor_builder
)
self._tmpl_attrs["instrumentor"] = instrumentor_builder(
project_id=self._tmpl_attrs.get("project")
)
model_builder = self._tmpl_attrs.get("model_builder") or _default_model_builder
self._tmpl_attrs["model"] = model_builder(
model_name=self._tmpl_attrs.get("model_name"),
model_kwargs=self._tmpl_attrs.get("model_kwargs"),
project=self._tmpl_attrs.get("project"),
location=self._tmpl_attrs.get("location"),
)
checkpointer_builder = self._tmpl_attrs.get("checkpointer_builder")
if checkpointer_builder:
checkpointer_kwargs = self._tmpl_attrs.get("checkpointer_kwargs") or {}
self._tmpl_attrs["checkpointer"] = checkpointer_builder(
**checkpointer_kwargs
)
runnable_builder = (
self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
)
self._tmpl_attrs["runnable"] = runnable_builder(
model=self._tmpl_attrs.get("model"),
tools=self._tmpl_attrs.get("tools"),
checkpointer=self._tmpl_attrs.get("checkpointer"),
model_tool_kwargs=self._tmpl_attrs.get("model_tool_kwargs"),
runnable_kwargs=self._tmpl_attrs.get("runnable_kwargs"),
)
def clone(self) -> "LanggraphAgent":
"""Returns a clone of the LanggraphAgent."""
import copy
return LanggraphAgent(
model=self._tmpl_attrs.get("model_name"),
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
model_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_kwargs")),
model_tool_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_tool_kwargs")),
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
checkpointer_kwargs=copy.deepcopy(
self._tmpl_attrs.get("checkpointer_kwargs")
),
model_builder=self._tmpl_attrs.get("model_builder"),
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
checkpointer_builder=self._tmpl_attrs.get("checkpointer_builder"),
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
)
def query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
                `.invoke()` method of the corresponding runnable.
Returns:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._tmpl_attrs.get("runnable"):
self.set_up()
return langchain_load_dump.dumpd(
self._tmpl_attrs.get("runnable").invoke(
input=input, config=config, **kwargs
)
)
def stream_query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs,
) -> Iterable[Any]:
"""Stream queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
                `.stream()` method of the corresponding runnable.
Yields:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._tmpl_attrs.get("runnable"):
self.set_up()
for chunk in self._tmpl_attrs.get("runnable").stream(
input=input,
config=config,
**kwargs,
):
yield langchain_load_dump.dumpd(chunk)
def get_state_history(
self,
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Iterable[Any]:
"""Gets the state history of the Agent.
Args:
config (Optional[RunnableConfig]):
Optional. The config for invoking the Agent.
**kwargs:
                Optional. Additional keyword arguments for the `.get_state_history()` method.
Yields:
Dict[str, Any]: The state history of the Agent.
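
        A minimal sketch (assumes the agent was set up with a checkpointer and
        that a prior query used the same thread ID):

        ```python
        for state in agent.get_state_history(
            config={"configurable": {"thread_id": "demo-thread"}},
        ):
            print(state["values"])
        ```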
"""
if not self._tmpl_attrs.get("runnable"):
self.set_up()
for state_snapshot in self._tmpl_attrs.get("runnable").get_state_history(
config=config,
**kwargs,
):
yield state_snapshot._asdict()
def get_state(
self,
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Gets the current state of the Agent.
Args:
config (Optional[RunnableConfig]):
Optional. The config for invoking the Agent.
**kwargs:
                Optional. Additional keyword arguments for the `.get_state()` method.
Returns:
Dict[str, Any]: The current state of the Agent.
"""
if not self._tmpl_attrs.get("runnable"):
self.set_up()
return (
self._tmpl_attrs.get("runnable")
.get_state(config=config, **kwargs)
._asdict()
)
def update_state(
self,
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Updates the state of the Agent.
Args:
config (Optional[RunnableConfig]):
Optional. The config for invoking the Agent.
**kwargs:
                Optional. Additional keyword arguments for the `.update_state()` method.
Returns:
Dict[str, Any]: The updated state of the Agent.
"""
if not self._tmpl_attrs.get("runnable"):
self.set_up()
return self._tmpl_attrs.get("runnable").update_state(config=config, **kwargs)
def register_operations(self) -> Mapping[str, Sequence[str]]:
"""Registers the operations of the Agent.
This mapping defines how different operation modes (e.g., "", "stream")
are implemented by specific methods of the Agent. The "default" mode,
        represented by the empty string "", is associated with the `query` API,
while the "stream" mode is associated with the `stream_query` API.
Returns:
Mapping[str, Sequence[str]]: A mapping of operation modes to a list
of method names that implement those operation modes.
"""
return {
"": ["query", "get_state", "update_state"],
"stream": ["stream_query", "get_state_history"],
}