Files
evo-ai/.venv/lib/python3.10/site-packages/vertexai/agent_engines/templates/ag2.py
2025-04-25 15:30:54 -03:00

488 lines
19 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
    # These aliases exist purely for static type checkers. Within this module
    # they are only referenced from quoted annotations or from local-variable
    # annotations, neither of which is evaluated at runtime, so AG2 and
    # OpenTelemetry stay optional dependencies.
    try:
        from autogen import agentchat

        ConversableAgent = agentchat.ConversableAgent
        ChatResult = agentchat.ChatResult
    except ImportError:
        ConversableAgent = Any
        # Bind every alias from the success branch so the fallback is
        # complete, mirroring the OpenTelemetry fallback below.
        ChatResult = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any
def _prepare_runnable_kwargs(
runnable_kwargs: Mapping[str, Any],
system_instruction: str,
runnable_name: str,
llm_config: Mapping[str, Any],
) -> Mapping[str, Any]:
"""Prepares the configuration for a runnable, applying defaults and enforcing constraints."""
if runnable_kwargs is None:
runnable_kwargs = {}
if (
"human_input_mode" in runnable_kwargs
and runnable_kwargs["human_input_mode"] != "NEVER"
):
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
f"human_input_mode={runnable_kwargs['human_input_mode']}"
"is not supported. Will be enforced to 'NEVER'."
)
runnable_kwargs["human_input_mode"] = "NEVER"
if "system_message" not in runnable_kwargs and system_instruction:
runnable_kwargs["system_message"] = system_instruction
if "name" not in runnable_kwargs:
runnable_kwargs["name"] = runnable_name
if "llm_config" not in runnable_kwargs:
runnable_kwargs["llm_config"] = llm_config
return runnable_kwargs
def _default_runnable_builder(
    **runnable_kwargs: Any,
) -> "ConversableAgent":
    """Builds the default runnable, an AG2 `ConversableAgent`.

    Args:
        **runnable_kwargs:
            Keyword arguments forwarded unchanged to the
            `ConversableAgent` constructor.

    Returns:
        The newly constructed `ConversableAgent`.
    """
    # Imported inside the function so that importing this module does not
    # require AG2 to be installed.
    from autogen import agentchat

    agent_cls = agentchat.ConversableAgent
    return agent_cls(**runnable_kwargs)
def _default_instrumentor_builder(project_id: str):
    """Builds the default instrumentor that exports AG2 traces to Cloud Trace.

    Args:
        project_id (str):
            The project passed to the Cloud Trace span exporter and used as
            the quota project of the credentials.

    Returns:
        An instrumented `AutogenInstrumentor`, or `None` (after logging a
        warning) when any of the tracing packages could not be imported.
    """
    from vertexai.agent_engines import _utils

    # NOTE(review): each helper presumably returns a falsy value (and warns)
    # when its package is missing - confirm against `_utils`.
    cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
    cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
    openinference_autogen = _utils._import_openinference_autogen_or_warn()
    opentelemetry = _utils._import_opentelemetry_or_warn()
    opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()

    tracing_packages = (
        cloud_trace_exporter,
        cloud_trace_v2,
        openinference_autogen,
        opentelemetry,
        opentelemetry_sdk_trace,
    )
    if not all(tracing_packages):
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            "enable_tracing=True but proceeding with tracing disabled "
            "because not all packages for tracing have been installed"
        )
        return None

    def _install_sdk_tracer_provider():
        # Creates an SDK tracer provider and registers it globally.
        fresh_provider = opentelemetry_sdk_trace.TracerProvider()
        opentelemetry.trace.set_tracer_provider(fresh_provider)
        return fresh_provider

    import google.auth

    credentials, _ = google.auth.default()
    trace_client = cloud_trace_v2.TraceServiceClient(
        credentials=credentials.with_quota_project(project_id),
    )
    exporter = cloud_trace_exporter.CloudTraceSpanExporter(
        project_id=project_id,
        client=trace_client,
    )
    processor: "SpanProcessor" = opentelemetry_sdk_trace.export.SimpleSpanProcessor(
        span_exporter=exporter,
    )

    # Get the appropriate tracer provider:
    # 1. If _TRACER_PROVIDER is already set, use that.
    # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
    #    variable is set, use that.
    # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
    # If none of the above is set, we log a warning and create a
    # tracer provider ourselves.
    provider: "TracerProvider" = opentelemetry.trace.get_tracer_provider()
    if not provider:
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            "No tracer provider. By default, "
            "we should get one of the following providers: "
            "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
            "or _PROXY_TRACER_PROVIDER."
        )
        provider = _install_sdk_tracer_provider()

    # Avoids AttributeError:
    # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
    # attribute 'add_span_processor'.
    if _utils.is_noop_or_proxy_tracer_provider(provider):
        provider = _install_sdk_tracer_provider()

    # Avoids OpenTelemetry client already exists error.
    _override_active_span_processor(
        provider,
        opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
    )
    provider.add_span_processor(processor)

    # Re-instrument on every call: when several AG2Agents are created, the
    # most recently created agent's instrumentation is the one in effect.
    instrumentor = openinference_autogen.AutogenInstrumentor()
    instrumentor.uninstrument()
    instrumentor.instrument()
    return instrumentor
def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing AG2 tools that are
usable with Gemini function calling.
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)
def _validate_tools(tools: Sequence[Callable[..., Any]]):
"""Validates that the tools are usable for tool calling."""
for tool in tools:
if isinstance(tool, Callable):
_validate_callable_parameters_are_annotated(tool)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple AG2Agents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
While we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new AG2Agent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
active_span_processor (SynchronousMultiSpanProcessor):
The active span processor overrides the tracer provider's
active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class AG2Agent:
    """An AG2 Agent."""

    def __init__(
        self,
        model: str,
        runnable_name: str,
        *,
        api_type: Optional[str] = None,
        llm_config: Optional[Mapping[str, Any]] = None,
        system_instruction: Optional[str] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None,
        tools: Optional[Sequence[Callable[..., Any]]] = None,
        enable_tracing: bool = False,
        instrumentor_builder: Optional[Callable[..., Any]] = None,
    ):
        """Initializes the AG2 Agent.

        Under-the-hood, assuming .set_up() is called, this will correspond to
        ```python
        # runnable_builder
        runnable = runnable_builder(
            llm_config=llm_config,
            system_message=system_instruction,
            **runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to
        ```python
        # llm_config
        llm_config = {
            "config_list": [{
                "project_id": initializer.global_config.project,
                "location": initializer.global_config.location,
                "model": "gemini-1.0-pro-001",
                "api_type": "google",
            }]
        }

        # runnable_builder
        runnable = ConversableAgent(
            llm_config=llm_config,
            name="Default AG2 Agent",
            system_message="You are a helpful AI Assistant.",
            human_input_mode="NEVER",
        )
        ```

        By default, if `llm_config` is not specified, a default configuration
        will be created using the provided `model` and `api_type`.

        If `runnable_builder` is not specified, a default runnable builder will
        be used, configured with the `system_instruction`, `runnable_name` and
        `runnable_kwargs`.

        Args:
            model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
                Used to create a default `llm_config` if one is not provided.
                This parameter is ignored if `llm_config` is provided.
            runnable_name (str):
                Required. The name of the runnable.
                This name is used as the default `runnable_kwargs["name"]`
                unless `runnable_kwargs` already contains a "name", in which
                case the provided `runnable_kwargs["name"]` will be used.
            api_type (str):
                Optional. The API type to use for the language model.
                Used to create a default `llm_config` if one is not provided.
                This parameter is ignored if `llm_config` is provided.
            llm_config (Mapping[str, Any]):
                Optional. Configuration dictionary for the language model.
                If provided, this configuration will be used directly.
                Otherwise, a default `llm_config` will be created using `model`
                and `api_type`. This `llm_config` is used as the default
                `runnable_kwargs["llm_config"]` unless `runnable_kwargs` already
                contains a "llm_config", in which case the provided
                `runnable_kwargs["llm_config"]` will be used.
            system_instruction (str):
                Optional. The system instruction for the agent.
                This instruction is used as the default
                `runnable_kwargs["system_message"]` unless `runnable_kwargs`
                already contains a "system_message", in which case the provided
                `runnable_kwargs["system_message"]` will be used.
            runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                the runnable. Details of the kwargs can be found in
                https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent.
                `runnable_kwargs` only supports `human_input_mode="NEVER"`.
                Other `human_input_mode` values will trigger a warning.
            runnable_builder (Callable[..., "ConversableAgent"]):
                Optional. Callable that returns a new runnable. This can be used
                for customizing the orchestration logic of the Agent.
                If not provided, a default runnable builder will be used.
            tools (Sequence[Callable[..., Any]]):
                Optional. The tools for the agent to be able to use. All input
                callables (e.g. function or class method) will be converted
                to an AG2 tool. Defaults to None.
            enable_tracing (bool):
                Optional. Whether to enable tracing in Cloud Trace. Defaults to
                False.
            instrumentor_builder (Callable[..., Any]):
                Optional. Callable that returns a new instrumentor. This can be
                used for customizing the instrumentation logic of the Agent.
                If not provided, a default instrumentor builder will be used.
                This parameter is ignored if `enable_tracing` is False.

        Raises:
            TypeError: If a callable in `tools` has a parameter without a
                type annotation.
        """
        from google.cloud.aiplatform import initializer

        # All template state lives in one dict; "runnable", "instrumentor"
        # and "ag2_tool_objects" stay empty here and are filled in by
        # `set_up()`.
        self._tmpl_attrs: Dict[str, Any] = {
            "project": initializer.global_config.project,
            "location": initializer.global_config.location,
            "model_name": model,
            "api_type": api_type or "google",
            "system_instruction": system_instruction,
            "runnable_name": runnable_name,
            "tools": [],
            "ag2_tool_objects": [],
            "runnable": None,
            "runnable_builder": runnable_builder,
            "instrumentor": None,
            "instrumentor_builder": instrumentor_builder,
            "enable_tracing": enable_tracing,
        }
        # A user-supplied llm_config wins; otherwise derive one from the
        # globally configured project/location plus `model` and `api_type`.
        self._tmpl_attrs["llm_config"] = llm_config or {
            "config_list": [
                {
                    "project_id": self._tmpl_attrs.get("project"),
                    "location": self._tmpl_attrs.get("location"),
                    "model": self._tmpl_attrs.get("model_name"),
                    "api_type": self._tmpl_attrs.get("api_type"),
                }
            ]
        }
        self._tmpl_attrs["runnable_kwargs"] = _prepare_runnable_kwargs(
            runnable_kwargs=runnable_kwargs,
            llm_config=self._tmpl_attrs.get("llm_config"),
            system_instruction=self._tmpl_attrs.get("system_instruction"),
            runnable_name=self._tmpl_attrs.get("runnable_name"),
        )
        if tools:
            # We validate tools at initialization for actionable feedback before
            # they are deployed.
            _validate_tools(tools)
            self._tmpl_attrs["tools"] = tools

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It builds the instrumentor (when tracing is enabled), converts the
        tools into AG2 tool objects and initializes the runnable.

        This method should not be called for an object that is being passed
        to the ReasoningEngine service for deployment, as it initializes
        clients that can not be serialized.
        """
        if self._tmpl_attrs.get("enable_tracing"):
            instrumentor_builder = (
                self._tmpl_attrs.get("instrumentor_builder")
                or _default_instrumentor_builder
            )
            self._tmpl_attrs["instrumentor"] = instrumentor_builder(
                project_id=self._tmpl_attrs.get("project")
            )

        # Set up tools. The `not ag2_tool_objects` guard keeps a repeated
        # `set_up()` call from wrapping the same tools a second time.
        tools = self._tmpl_attrs.get("tools")
        ag2_tool_objects = self._tmpl_attrs.get("ag2_tool_objects")
        if tools and not ag2_tool_objects:
            from vertexai.agent_engines import _utils

            autogen_tools = _utils._import_autogen_tools_or_warn()
            if autogen_tools:
                # Extend in place so the list stored in `_tmpl_attrs` is the
                # one that gets populated.
                ag2_tool_objects.extend(
                    autogen_tools.Tool(func_or_tool=tool) for tool in tools
                )

        # Set up runnable.
        runnable_builder = (
            self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
        )
        self._tmpl_attrs["runnable"] = runnable_builder(
            **self._tmpl_attrs.get("runnable_kwargs")
        )

    def clone(self) -> "AG2Agent":
        """Returns a clone of the AG2Agent.

        The clone is built from the stored constructor arguments (mutable
        ones are deep-copied), so it starts without a runnable or
        instrumentor until its own `set_up()` runs.
        """
        import copy

        return AG2Agent(
            model=self._tmpl_attrs.get("model_name"),
            api_type=self._tmpl_attrs.get("api_type"),
            llm_config=copy.deepcopy(self._tmpl_attrs.get("llm_config")),
            system_instruction=self._tmpl_attrs.get("system_instruction"),
            runnable_name=self._tmpl_attrs.get("runnable_name"),
            tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
            runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
            runnable_builder=self._tmpl_attrs.get("runnable_builder"),
            enable_tracing=self._tmpl_attrs.get("enable_tracing"),
            instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
        )

    def query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        max_turns: Optional[int] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Queries the Agent with the given input.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent. A string is
                wrapped as `{"content": input}`.
            max_turns (int):
                Optional. The maximum number of turns to run the agent for.
                If not provided, the agent will run indefinitely.
                If `max_turns` is a `float`, it will be converted to `int`
                through rounding.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.run()` method of the corresponding runnable.
                Details of the kwargs can be found in
                https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent#run.
                The `user_input` parameter defaults to `False`, and should not
                be passed through `kwargs`.

        Returns:
            The output of querying the Agent with the given input.
        """
        if isinstance(input, str):
            input = {"content": input}

        if max_turns and isinstance(max_turns, float):
            # Supporting auto-conversion float to int.
            max_turns = round(max_turns)

        if "user_input" in kwargs:
            from google.cloud.aiplatform import base

            _LOGGER = base.Logger(__name__)
            _LOGGER.warning(
                "The `user_input` parameter should not be passed through "
                "kwargs. The `user_input` defaults to `False`."
            )
            # Drop it so it cannot clash with the explicit `user_input=False`
            # passed to `.run()` below.
            kwargs.pop("user_input")

        # Lazily set up on first use.
        if not self._tmpl_attrs.get("runnable"):
            self.set_up()

        response = self._tmpl_attrs.get("runnable").run(
            message=input,
            user_input=False,
            tools=self._tmpl_attrs.get("ag2_tool_objects"),
            max_turns=max_turns,
            **kwargs,
        )

        from vertexai.agent_engines import _utils

        return _utils.to_json_serializable_autogen_object(response)