structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,47 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for working with reasoning engines."""
# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.reasoning_engines._reasoning_engines import (
Queryable,
ReasoningEngine,
)
from vertexai.preview.reasoning_engines.templates.adk import (
AdkApp,
)
from vertexai.preview.reasoning_engines.templates.ag2 import (
AG2Agent,
)
from vertexai.preview.reasoning_engines.templates.langchain import (
LangchainAgent,
)
from vertexai.preview.reasoning_engines.templates.langgraph import (
LanggraphAgent,
)
from vertexai.preview.reasoning_engines.templates.llama_index import (
LlamaIndexQueryPipelineAgent,
)
__all__ = (
"AdkApp",
"AG2Agent",
"LangchainAgent",
"LanggraphAgent",
"LlamaIndexQueryPipelineAgent",
"Queryable",
"ReasoningEngine",
)
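The module above only re-exports the template classes, so downstream code can construct any of them from the `vertexai.preview.reasoning_engines` namespace. A minimal sketch, assuming `vertexai.init` has already been called with a valid project and location (the values below are placeholders):

```python
import vertexai
from vertexai.preview import reasoning_engines

vertexai.init(project="my-project", location="us-central1")  # placeholder values

# Any of the re-exported templates can be instantiated through this namespace.
agent = reasoning_engines.LangchainAgent(model="gemini-1.0-pro")
```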

View File

@@ -0,0 +1,651 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
if TYPE_CHECKING:
try:
from google.adk.events.event import Event
Event = Event
except (ImportError, AttributeError):
Event = Any
try:
from google.adk.agents import BaseAgent
BaseAgent = BaseAgent
except (ImportError, AttributeError):
BaseAgent = Any
try:
from google.adk.sessions import BaseSessionService
BaseSessionService = BaseSessionService
except (ImportError, AttributeError):
BaseSessionService = Any
try:
from google.adk.artifacts import BaseArtifactService
BaseArtifactService = BaseArtifactService
except (ImportError, AttributeError):
BaseArtifactService = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
_DEFAULT_APP_NAME = "default-app-name"
_DEFAULT_USER_ID = "default-user-id"
class _ArtifactVersion:
def __init__(self, **kwargs):
self.version: Optional[str] = kwargs.get("version")
self.data = kwargs.get("data")
def dump(self) -> Dict[str, Any]:
result = {}
if self.version:
result["version"] = self.version
if self.data:
result["data"] = self.data
return result
class _Artifact:
def __init__(self, **kwargs):
self.file_name: Optional[str] = kwargs.get("file_name")
self.versions: List[_ArtifactVersion] = kwargs.get("versions")
def dump(self) -> Dict[str, Any]:
result = {}
if self.file_name:
result["file_name"] = self.file_name
if self.versions:
result["versions"] = [version.dump() for version in self.versions]
return result
class _Authorization:
def __init__(self, **kwargs):
self.access_token: Optional[str] = kwargs.get("access_token") or kwargs.get(
"accessToken"
)
class _StreamRunRequest:
"""Request object for `streaming_agent_run_with_events` method."""
def __init__(self, **kwargs):
from google.adk.events.event import Event
from google.genai import types
self.message: Optional[types.Content] = kwargs.get("message")
# The new message to be processed by the agent.
self.events: Optional[List[Event]] = kwargs.get("events")
# List of preceding events that happened in the same session.
self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
# List of artifacts belonging to the session.
self.authorizations: Dict[str, _Authorization] = kwargs.get(
"authorizations", {}
)
# The authorizations of the user, keyed by authorization ID.
self.user_id: Optional[str] = kwargs.get("user_id", _DEFAULT_USER_ID)
# The user ID.
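# A minimal sketch of the JSON payload that `streaming_agent_run_with_events`
# parses into a _StreamRunRequest; all field values below are hypothetical.
_EXAMPLE_STREAM_RUN_REQUEST = {
    "message": {"role": "user", "parts": [{"text": "Hello"}]},
    "events": [],  # preceding events from the same session, if any
    "artifacts": [],  # artifacts already belonging to the session, if any
    "authorizations": {"my_auth_id": {"accessToken": "ya29.example-token"}},
    "user_id": "user-123",
}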
class _StreamingRunResponse:
"""Response object for `streaming_agent_run_with_events` method.
It contains the generated events together with the associated artifacts.
"""
def __init__(self, **kwargs):
self.events: Optional[List["Event"]] = kwargs.get("events")
# List of generated events.
self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
# List of artifacts belonging to the session.
def dump(self) -> Dict[str, Any]:
result = {}
if self.events:
result["events"] = []
for event in self.events:
event_dict = event.model_dump(exclude_none=True)
event_dict["invocation_id"] = event_dict.get("invocation_id", "")
result["events"].append(event_dict)
if self.artifacts:
result["artifacts"] = [artifact.dump() for artifact in self.artifacts]
return result
def _default_instrumentor_builder(project_id: str):
from vertexai.agent_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=project_id,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(project_id),
),
)
span_processor = opentelemetry_sdk_trace.export.BatchSpanProcessor(
span_exporter=span_exporter,
)
tracer_provider = opentelemetry.trace.get_tracer_provider()
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
return None
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages (i.e. `google-cloud-trace`, `opentelemetry-sdk`, "
"`opentelemetry-exporter-gcp-trace`) for tracing have been installed"
)
return None
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple LangchainAgents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
While we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new LangchainAgent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
active_span_processor (SynchronousMultiSpanProcessor):
The active span processor overrides the tracer provider's
active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class AdkApp:
def __init__(
self,
*,
agent: "BaseAgent",
enable_tracing: bool = False,
session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None,
artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None,
env_vars: Optional[Dict[str, str]] = None,
):
"""An ADK Application."""
from google.cloud.aiplatform import initializer
self._tmpl_attrs: Dict[str, Any] = {
"project": initializer.global_config.project,
"location": initializer.global_config.location,
"agent": agent,
"enable_tracing": enable_tracing,
"session_service_builder": session_service_builder,
"artifact_service_builder": artifact_service_builder,
"app_name": _DEFAULT_APP_NAME,
"env_vars": env_vars or {},
}
def _init_session(
self,
session_service: "BaseSessionService",
artifact_service: "BaseArtifactService",
request: _StreamRunRequest,
):
"""Initializes the session, and returns the session id."""
from google.adk.events.event import Event
import random
session_state = None
if request.authorizations:
session_state = {}
for auth_id, auth in request.authorizations.items():
auth = _Authorization(**auth)
session_state[f"temp:{auth_id}"] = auth.access_token
session_id = f"temp_session_{random.randbytes(8).hex()}"
session = session_service.create_session(
app_name=self._tmpl_attrs.get("app_name"),
user_id=request.user_id,
session_id=session_id,
state=session_state,
)
if not session:
raise RuntimeError("Create session failed.")
if request.events:
for event in request.events:
session_service.append_event(session, Event(**event))
if request.artifacts:
for artifact in request.artifacts:
artifact = _Artifact(**artifact)
for version_data in sorted(
artifact.versions, key=lambda x: x["version"]
):
version_data = _ArtifactVersion(**version_data)
saved_version = artifact_service.save_artifact(
app_name=self._tmpl_attrs.get("app_name"),
user_id=request.user_id,
session_id=session_id,
filename=artifact.file_name,
artifact=version_data.data,
)
if saved_version != version_data.version:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.debug(
"Artifact '%s' saved at version %s instead of %s",
artifact.file_name,
saved_version,
version_data.version,
)
return session
def _convert_response_events(
self,
user_id: str,
session_id: str,
events: List["Event"],
artifact_service: Optional["BaseArtifactService"],
) -> _StreamingRunResponse:
"""Converts the events to the streaming run response object."""
import collections
result = _StreamingRunResponse(events=events, artifacts=[])
# Save the generated artifacts into the result object.
artifact_versions = collections.defaultdict(list)
for event in events:
if event.actions and event.actions.artifact_delta:
for key, version in event.actions.artifact_delta.items():
artifact_versions[key].append(version)
for key, versions in artifact_versions.items():
result.artifacts.append(
_Artifact(
file_name=key,
versions=[
_ArtifactVersion(
version=version,
data=artifact_service.load_artifact(
app_name=self._tmpl_attrs.get("app_name"),
user_id=user_id,
session_id=session_id,
filename=key,
version=version,
),
)
for version in versions
],
)
)
return result.dump()
def clone(self):
"""Returns a clone of the ADK application."""
import copy
return AdkApp(
agent=copy.deepcopy(self._tmpl_attrs.get("agent")),
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
session_service_builder=self._tmpl_attrs.get("session_service_builder"),
artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"),
env_vars=self._tmpl_attrs.get("env_vars"),
)
def set_up(self):
"""Sets up the ADK application."""
import os
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.artifacts.in_memory_artifact_service import (
InMemoryArtifactService,
)
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
project = self._tmpl_attrs.get("project")
os.environ["GOOGLE_CLOUD_PROJECT"] = project
location = self._tmpl_attrs.get("location")
os.environ["GOOGLE_CLOUD_LOCATION"] = location
if self._tmpl_attrs.get("enable_tracing"):
self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder(
project_id=project
)
for key, value in self._tmpl_attrs.get("env_vars").items():
os.environ[key] = value
artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder")
if artifact_service_builder:
self._tmpl_attrs["artifact_service"] = artifact_service_builder()
else:
self._tmpl_attrs["artifact_service"] = InMemoryArtifactService()
session_service_builder = self._tmpl_attrs.get("session_service_builder")
if session_service_builder:
self._tmpl_attrs["session_service"] = session_service_builder()
elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ:
from google.adk.sessions.vertex_ai_session_service import (
VertexAiSessionService,
)
self._tmpl_attrs["session_service"] = VertexAiSessionService(
project=project,
location=location,
)
self._tmpl_attrs["app_name"] = os.environ.get(
"GOOGLE_CLOUD_AGENT_ENGINE_ID",
self._tmpl_attrs.get("app_name"),
)
else:
self._tmpl_attrs["session_service"] = InMemorySessionService()
self._tmpl_attrs["runner"] = Runner(
agent=self._tmpl_attrs.get("agent"),
session_service=self._tmpl_attrs.get("session_service"),
artifact_service=self._tmpl_attrs.get("artifact_service"),
app_name=self._tmpl_attrs.get("app_name"),
)
self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
self._tmpl_attrs["in_memory_runner"] = Runner(
agent=self._tmpl_attrs.get("agent"),
session_service=self._tmpl_attrs.get("in_memory_session_service"),
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
app_name=self._tmpl_attrs.get("app_name"),
)
def stream_query(
self,
*,
message: str,
user_id: str,
session_id: Optional[str] = None,
**kwargs,
):
"""Streams responses from the ADK application in response to a message.
Args:
message (str):
Required. The message to stream responses for.
user_id (str):
Required. The ID of the user.
session_id (str):
Optional. The ID of the session. If not provided, a new
session will be created for the user.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
runner.
Yields:
The output of querying the ADK application.
"""
from google.genai import types
content = types.Content(role="user", parts=[types.Part(text=message)])
if not self._tmpl_attrs.get("runner"):
self.set_up()
if not session_id:
session = self.create_session(user_id=user_id)
session_id = session.id
for event in self._tmpl_attrs.get("runner").run(
user_id=user_id, session_id=session_id, new_message=content, **kwargs
):
yield event.model_dump(exclude_none=True)
def streaming_agent_run_with_events(self, request_json: str):
import json
from google.genai import types
request = _StreamRunRequest(**json.loads(request_json))
if not self._tmpl_attrs.get("in_memory_runner"):
self.set_up()
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
# Prepare the in-memory session.
if not self._tmpl_attrs.get("in_memory_artifact_service"):
self.set_up()
if not self._tmpl_attrs.get("in_memory_session_service"):
self.set_up()
session = self._init_session(
session_service=self._tmpl_attrs.get("in_memory_session_service"),
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
request=request,
)
if not session:
raise RuntimeError("Session initialization failed.")
# Run the agent.
for event in self._tmpl_attrs.get("in_memory_runner").run(
user_id=request.user_id,
session_id=session.id,
new_message=types.Content(**request.message),
):
yield self._convert_response_events(
user_id=request.user_id,
session_id=session.id,
events=[event],
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
)
self._tmpl_attrs.get("in_memory_session_service").delete_session(
app_name=self._tmpl_attrs.get("app_name"),
user_id=request.user_id,
session_id=session.id,
)
def get_session(
self,
*,
user_id: str,
session_id: str,
**kwargs,
):
"""Get a session for the given user.
Args:
user_id (str):
Required. The ID of the user.
session_id (str):
Required. The ID of the session.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
session service.
Returns:
Session: The session instance.
Raises:
RuntimeError: If the session is not found.
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
session = self._tmpl_attrs.get("session_service").get_session(
app_name=self._tmpl_attrs.get("app_name"),
user_id=user_id,
session_id=session_id,
**kwargs,
)
if not session:
raise RuntimeError(
"Session not found. Please create it using .create_session()"
)
return session
def list_sessions(self, *, user_id: str, **kwargs):
"""List sessions for the given user.
Args:
user_id (str):
Required. The ID of the user.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
session service.
Returns:
ListSessionsResponse: The list of sessions.
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
return self._tmpl_attrs.get("session_service").list_sessions(
app_name=self._tmpl_attrs.get("app_name"),
user_id=user_id,
**kwargs,
)
def create_session(
self,
*,
user_id: str,
session_id: Optional[str] = None,
state: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""Creates a new session.
Args:
user_id (str):
Required. The ID of the user.
session_id (str):
Optional. The ID of the session. If not provided, an ID
will be generated for the session.
state (dict[str, Any]):
Optional. The initial state of the session.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
session service.
Returns:
Session: The newly created session instance.
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
session = self._tmpl_attrs.get("session_service").create_session(
app_name=self._tmpl_attrs.get("app_name"),
user_id=user_id,
session_id=session_id,
state=state,
**kwargs,
)
return session
def delete_session(
self,
*,
user_id: str,
session_id: str,
**kwargs,
):
"""Deletes a session for the given user.
Args:
user_id (str):
Required. The ID of the user.
session_id (str):
Required. The ID of the session.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
session service.
"""
if not self._tmpl_attrs.get("session_service"):
self.set_up()
self._tmpl_attrs.get("session_service").delete_session(
app_name=self._tmpl_attrs.get("app_name"),
user_id=user_id,
session_id=session_id,
**kwargs,
)
def register_operations(self) -> Dict[str, List[str]]:
"""Registers the operations of the ADK application."""
return {
"": [
"get_session",
"list_sessions",
"create_session",
"delete_session",
],
"stream": ["stream_query", "streaming_agent_run_with_events"],
}

View File

@@ -0,0 +1,474 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
try:
from autogen import agentchat
ConversableAgent = agentchat.ConversableAgent
ChatResult = agentchat.ChatResult
except ImportError:
ConversableAgent = Any
ChatResult = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
def _prepare_runnable_kwargs(
runnable_kwargs: Mapping[str, Any],
system_instruction: str,
runnable_name: str,
llm_config: Mapping[str, Any],
) -> Mapping[str, Any]:
"""Prepares the configuration for a runnable, applying defaults and enforcing constraints."""
if runnable_kwargs is None:
runnable_kwargs = {}
if (
"human_input_mode" in runnable_kwargs
and runnable_kwargs["human_input_mode"] != "NEVER"
):
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
f"human_input_mode={runnable_kwargs['human_input_mode']}"
"is not supported. Will be enforced to 'NEVER'."
)
runnable_kwargs["human_input_mode"] = "NEVER"
if "system_message" not in runnable_kwargs and system_instruction:
runnable_kwargs["system_message"] = system_instruction
if "name" not in runnable_kwargs:
runnable_kwargs["name"] = runnable_name
if "llm_config" not in runnable_kwargs:
runnable_kwargs["llm_config"] = llm_config
return runnable_kwargs
def _default_runnable_builder(
**runnable_kwargs: Any,
) -> "ConversableAgent":
from autogen import agentchat
return agentchat.ConversableAgent(**runnable_kwargs)
def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing AG2 tools that are
usable with Gemini function calling.
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)
def _validate_tools(tools: Sequence[Callable[..., Any]]):
"""Validates that the tools are usable for tool calling."""
for tool in tools:
if isinstance(tool, Callable):
_validate_callable_parameters_are_annotated(tool)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple AG2Agents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
While we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new AG2Agent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
active_span_processor (SynchronousMultiSpanProcessor):
The active span processor overrides the tracer provider's
active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class AG2Agent:
"""An AG2 Agent.
See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/ag2
for details.
"""
def __init__(
self,
model: str,
runnable_name: str,
*,
api_type: Optional[str] = None,
llm_config: Optional[Mapping[str, Any]] = None,
system_instruction: Optional[str] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None,
tools: Optional[Sequence[Callable[..., Any]]] = None,
enable_tracing: bool = False,
):
"""Initializes the AG2 Agent.
Under-the-hood, assuming .set_up() is called, this will correspond to
```python
# runnable_builder
runnable = runnable_builder(
llm_config=llm_config,
system_message=system_instruction,
**runnable_kwargs,
)
```
When everything is based on their default values, this corresponds to
```python
# llm_config
llm_config = {
"config_list": [{
"project_id": initializer.global_config.project,
"location": initializer.global_config.location,
"model": "gemini-1.0-pro-001",
"api_type": "google",
}]
}
# runnable_builder
runnable = ConversableAgent(
llm_config=llm_config,
name="Default AG2 Agent"
system_message="You are a helpful AI Assistant.",
human_input_mode="NEVER",
)
```
By default, if `llm_config` is not specified, a default configuration
will be created using the provided `model` and `api_type`.
If `runnable_builder` is not specified, a default runnable builder will
be used, configured with the `system_instruction`, `runnable_name` and
`runnable_kwargs`.
Args:
model (str):
Required. The name of the model (e.g. "gemini-1.0-pro").
Used to create a default `llm_config` if one is not provided.
This parameter is ignored if `llm_config` is provided.
runnable_name (str):
Required. The name of the runnable.
This name is used as the default `runnable_kwargs["name"]`
unless `runnable_kwargs` already contains a "name", in which
case the provided `runnable_kwargs["name"]` will be used.
api_type (str):
Optional. The API type to use for the language model.
Used to create a default `llm_config` if one is not provided.
This parameter is ignored if `llm_config` is provided.
llm_config (Mapping[str, Any]):
Optional. Configuration dictionary for the language model.
If provided, this configuration will be used directly.
Otherwise, a default `llm_config` will be created using `model`
and `api_type`. This `llm_config` is used as the default
`runnable_kwargs["llm_config"]` unless `runnable_kwargs` already
contains a "llm_config", in which case the provided
`runnable_kwargs["llm_config"]` will be used.
system_instruction (str):
Optional. The system instruction for the agent.
This instruction is used as the default
`runnable_kwargs["system_message"]` unless `runnable_kwargs`
already contains a "system_message", in which case the provided
`runnable_kwargs["system_message"]` will be used.
runnable_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
the runnable. Details of the kwargs can be found in
https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent.
`runnable_kwargs` only supports `human_input_mode="NEVER"`.
Other `human_input_mode` values will trigger a warning.
runnable_builder (Callable[..., "ConversableAgent"]):
Optional. Callable that returns a new runnable. This can be used
for customizing the orchestration logic of the Agent.
If not provided, a default runnable builder will be used.
tools (Sequence[Callable[..., Any]]):
Optional. The tools for the agent to be able to use. All input
callables (e.g. function or class method) will be converted
to an AG2 tool. Defaults to None.
enable_tracing (bool):
Optional. Whether to enable tracing in Cloud Trace. Defaults to
False.
"""
from google.cloud.aiplatform import initializer
# Set up llm config.
self._project = initializer.global_config.project
self._location = initializer.global_config.location
self._model_name = model or "gemini-1.0-pro-001"
self._api_type = api_type or "google"
self._llm_config = llm_config or {
"config_list": [
{
"project_id": self._project,
"location": self._location,
"model": self._model_name,
"api_type": self._api_type,
}
]
}
self._system_instruction = system_instruction
self._runnable_name = runnable_name
self._runnable_kwargs = _prepare_runnable_kwargs(
runnable_kwargs=runnable_kwargs,
llm_config=self._llm_config,
system_instruction=self._system_instruction,
runnable_name=self._runnable_name,
)
self._tools = []
if tools:
# We validate tools at initialization for actionable feedback before
# they are deployed.
_validate_tools(tools)
self._tools = tools
self._ag2_tool_objects = []
self._runnable = None
self._runnable_builder = runnable_builder
self._instrumentor = None
self._enable_tracing = enable_tracing
def set_up(self):
"""Sets up the agent for execution of queries at runtime.
It initializes the runnable, binds the runnable with tools.
This method should not be called for an object that is being passed to
the ReasoningEngine service for deployment, as it initializes clients
that cannot be serialized.
"""
if self._enable_tracing:
from vertexai.reasoning_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_autogen = _utils._import_openinference_autogen_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_autogen,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=self._project,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(self._project),
),
)
span_processor: SpanProcessor = (
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
)
tracer_provider: TracerProvider = (
opentelemetry.trace.get_tracer_provider()
)
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
# Keep the instrumentation up-to-date.
# When creating multiple AG2Agents,
# we need to keep the instrumentation up-to-date.
# We deliberately override the instrument each time,
# so that if different agents end up using different
# instrumentations, we guarantee that the user is always
# working with the most recent agent's instrumentation.
self._instrumentor = openinference_autogen.AutogenInstrumentor()
self._instrumentor.uninstrument()
self._instrumentor.instrument()
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
# Set up tools.
if self._tools and not self._ag2_tool_objects:
from vertexai.reasoning_engines import _utils
autogen_tools = _utils._import_autogen_tools_or_warn()
if autogen_tools:
for tool in self._tools:
self._ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))
# Set up runnable.
runnable_builder = self._runnable_builder or _default_runnable_builder
self._runnable = runnable_builder(
**self._runnable_kwargs,
)
def clone(self) -> "AG2Agent":
"""Returns a clone of the AG2Agent."""
import copy
return AG2Agent(
model=self._model_name,
api_type=self._api_type,
llm_config=copy.deepcopy(self._llm_config),
system_instruction=self._system_instruction,
runnable_name=self._runnable_name,
tools=copy.deepcopy(self._tools),
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
runnable_builder=self._runnable_builder,
enable_tracing=self._enable_tracing,
)
def query(
self,
*,
input: Union[str, Mapping[str, Any]],
max_turns: Optional[int] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Queries the Agent with the given input.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
max_turns (int):
Optional. The maximum number of turns to run the agent for.
If not provided, the agent will run indefinitely.
If `max_turns` is a `float`, it will be converted to `int`
through rounding.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.run()` method of the corresponding runnable.
Details of the kwargs can be found in
https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent#run.
The `user_input` parameter defaults to `False`, and should not
be passed through `kwargs`.
Returns:
The output of querying the Agent with the given input.
"""
if isinstance(input, str):
input = {"content": input}
if max_turns and isinstance(max_turns, float):
# Supporting auto-conversion float to int.
max_turns = round(max_turns)
if "user_input" in kwargs:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"The `user_input` parameter should not be passed through"
"kwargs. The `user_input` defaults to `False`."
)
kwargs.pop("user_input")
if not self._runnable:
self.set_up()
from vertexai.reasoning_engines import _utils
# `.run()` will return a `ChatResult` object, which is a dataclass.
# We need to convert it to a JSON-serializable object.
# More details of `ChatResult` can be found in
# https://docs.ag2.ai/docs/api-reference/autogen/ChatResult.
return _utils.dataclass_to_dict(
self._runnable.run(
input,
user_input=False,
tools=self._ag2_tool_objects,
max_turns=max_turns,
**kwargs,
)
)
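A hedged usage sketch of `AG2Agent`; the tool below is a stub and the model/runnable names are placeholders, with `vertexai.init` assumed to have been called:

```python
from vertexai.preview.reasoning_engines import AG2Agent

def get_weather(city: str) -> str:
    """Returns a stubbed weather report for a city (illustrative only)."""
    return f"It is sunny in {city}."

agent = AG2Agent(
    model="gemini-1.0-pro-001",
    runnable_name="Weather Agent",
    tools=[get_weather],
)
response = agent.query(input="What is the weather in Boston?", max_turns=2)
```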

View File

@@ -0,0 +1,643 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
try:
from langchain_core import runnables
from langchain_core import tools as lc_tools
from langchain_core.language_models import base as lc_language_models
BaseTool = lc_tools.BaseTool
BaseLanguageModel = lc_language_models.BaseLanguageModel
GetSessionHistoryCallable = runnables.history.GetSessionHistoryCallable
RunnableConfig = runnables.RunnableConfig
RunnableSerializable = runnables.RunnableSerializable
except ImportError:
BaseTool = Any
BaseLanguageModel = Any
GetSessionHistoryCallable = Any
RunnableConfig = Any
RunnableSerializable = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType
_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
# https://github.com/langchain-ai/langchain/blob/5784dfed001730530637793bea1795d9d5a7c244/libs/core/langchain_core/runnables/history.py#L237-L241
runnable_kwargs = {
# input_messages_key (str): Must be specified if the underlying
# agent accepts a dict as input.
"input_messages_key": "input",
# output_messages_key (str): Must be specified if the underlying
# agent returns a dict as output.
"output_messages_key": "output",
}
if has_history:
# history_messages_key (str): Must be specified if the underlying
# agent accepts a dict as input and a separate key for historical
# messages.
runnable_kwargs["history_messages_key"] = "history"
return runnable_kwargs
def _default_output_parser():
try:
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.output_parsers.openai_tools import (
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
)
return ToolsAgentOutputParser()
def _default_model_builder(
model_name: str,
*,
project: str,
location: str,
model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI
model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
def _default_runnable_builder(
model: "BaseLanguageModel",
*,
system_instruction: Optional[str] = None,
tools: Optional[Sequence["_ToolLike"]] = None,
prompt: Optional["RunnableSerializable"] = None,
output_parser: Optional["RunnableSerializable"] = None,
chat_history: Optional["GetSessionHistoryCallable"] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
from langchain_core import tools as lc_tools
from langchain.agents import AgentExecutor
from langchain.tools.base import StructuredTool
# The prompt template and runnable_kwargs need to be customized depending
# on whether the user intends for the agent to have history. The way the
# user would reflect that is by setting chat_history (which defaults to
# None).
has_history: bool = chat_history is not None
prompt = prompt or _default_prompt(
has_history=has_history,
system_instruction=system_instruction,
)
output_parser = output_parser or _default_output_parser()
model_tool_kwargs = model_tool_kwargs or {}
agent_executor_kwargs = agent_executor_kwargs or {}
runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history)
if tools:
model = model.bind_tools(tools=tools, **model_tool_kwargs)
else:
tools = []
agent_executor = AgentExecutor(
agent=prompt | model | output_parser,
tools=[
tool
if isinstance(tool, lc_tools.BaseTool)
else StructuredTool.from_function(tool)
for tool in tools
if isinstance(tool, (Callable, lc_tools.BaseTool))
],
**agent_executor_kwargs,
)
if has_history:
from langchain_core.runnables.history import RunnableWithMessageHistory
return RunnableWithMessageHistory(
runnable=agent_executor,
get_session_history=chat_history,
**runnable_kwargs,
)
return agent_executor
def _default_prompt(
has_history: bool,
system_instruction: Optional[str] = None,
) -> "RunnableSerializable":
from langchain_core import prompts
try:
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
except (ModuleNotFoundError, ImportError):
# Fallback to an older version if needed.
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages as format_to_tool_messages,
)
system_instructions = []
if system_instruction:
system_instructions = [("system", system_instruction)]
if has_history:
return {
"history": lambda x: x["history"],
"input": lambda x: x["input"],
"agent_scratchpad": (
lambda x: format_to_tool_messages(x["intermediate_steps"])
),
} | prompts.ChatPromptTemplate.from_messages(
system_instructions
+ [
prompts.MessagesPlaceholder(variable_name="history"),
("user", "{input}"),
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
else:
return {
"input": lambda x: x["input"],
"agent_scratchpad": (
lambda x: format_to_tool_messages(x["intermediate_steps"])
),
} | prompts.ChatPromptTemplate.from_messages(
system_instructions
+ [
("user", "{input}"),
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing LangChain tools that are
usable with Gemini function calling.
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)
def _validate_tools(tools: Sequence["_ToolLike"]):
"""Validates that the tools are usable for tool calling."""
for tool in tools:
if isinstance(tool, Callable):
_validate_callable_parameters_are_annotated(tool)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple LangchainAgents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
While we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new LangchainAgent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
active_span_processor (SynchronousMultiSpanProcessor):
The active span processor overrides the tracer provider's
active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class LangchainAgent:
"""A Langchain Agent.
See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/langchain
for details.
"""
def __init__(
self,
model: str,
*,
system_instruction: Optional[str] = None,
prompt: Optional["RunnableSerializable"] = None,
tools: Optional[Sequence["_ToolLike"]] = None,
output_parser: Optional["RunnableSerializable"] = None,
chat_history: Optional["GetSessionHistoryCallable"] = None,
model_kwargs: Optional[Mapping[str, Any]] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
model_builder: Optional[Callable] = None,
runnable_builder: Optional[Callable] = None,
enable_tracing: bool = False,
):
"""Initializes the LangchainAgent.
Under-the-hood, assuming .set_up() is called, this will correspond to
```
model = model_builder(model_name=model, model_kwargs=model_kwargs)
runnable = runnable_builder(
prompt=prompt,
model=model,
tools=tools,
output_parser=output_parser,
chat_history=chat_history,
agent_executor_kwargs=agent_executor_kwargs,
runnable_kwargs=runnable_kwargs,
)
```
When everything is based on their default values, this corresponds to
```
# model_builder
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name=model, **model_kwargs)
# runnable_builder
from langchain import agents
from langchain_core.runnables.history import RunnableWithMessageHistory
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
agent_executor = agents.AgentExecutor(
agent=prompt | llm_with_tools | output_parser,
tools=tools,
**agent_executor_kwargs,
)
runnable = RunnableWithMessageHistory(
runnable=agent_executor,
get_session_history=chat_history,
**runnable_kwargs,
)
```
Args:
model (str):
Required. The name of the model (e.g. "gemini-1.0-pro").
system_instruction (str):
Optional. The system instruction to use for the agent. This
argument should not be specified if `prompt` is specified.
prompt (langchain_core.runnables.RunnableSerializable):
Optional. The prompt template for the model. Defaults to a
ChatPromptTemplate.
tools (Sequence[langchain_core.tools.BaseTool, Callable]):
Optional. The tools for the agent to be able to use. All input
callables (e.g. function or class method) will be converted
to a langchain.tools.base.StructuredTool. Defaults to None.
output_parser (langchain_core.runnables.RunnableSerializable):
Optional. The output parser for the model. Defaults to an
output parser that works with Gemini function-calling.
chat_history (langchain_core.runnables.history.GetSessionHistoryCallable):
Optional. Callable that returns a new BaseChatMessageHistory.
Defaults to None, i.e. chat_history is not preserved.
model_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
chat_models.ChatVertexAI. An example would be
```
{
# temperature (float): Sampling temperature, it controls the
# degree of randomness in token selection.
"temperature": 0.28,
# max_output_tokens (int): Token limit determines the
# maximum amount of text output from one prompt.
"max_output_tokens": 1000,
# top_p (float): Tokens are selected from most probable to
# least, until the sum of their probabilities equals the
# top_p value.
"top_p": 0.95,
# top_k (int): How the model selects tokens for output, the
# next token is selected from among the top_k most probable
# tokens.
"top_k": 40,
}
```
model_tool_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments when binding tools to the
model using `model.bind_tools()`.
agent_executor_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
langchain.agents.AgentExecutor. An example would be
```
{
# Whether to return the agent's trajectory of intermediate
# steps at the end in addition to the final output.
"return_intermediate_steps": False,
# The maximum number of steps to take before ending the
# execution loop.
"max_iterations": 15,
# The method to use for early stopping if the agent never
# returns `AgentFinish`. Either 'force' or 'generate'.
"early_stopping_method": "force",
# How to handle errors raised by the agent's output parser.
# Defaults to `False`, which raises the error.
"handle_parsing_errors": False,
}
```
runnable_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
langchain.runnables.history.RunnableWithMessageHistory if
chat_history is specified. If chat_history is None, this will be
ignored.
model_builder (Callable):
Optional. Callable that returns a new language model. Defaults
to a callable that returns ChatVertexAI based on `model`,
`model_kwargs` and the parameters in `vertexai.init`.
runnable_builder (Callable):
Optional. Callable that returns a new runnable. This can be used
for customizing the orchestration logic of the Agent based on
the model returned by `model_builder` and the rest of the input
arguments.
enable_tracing (bool):
Optional. Whether to enable tracing in Cloud Trace. Defaults to
False.
Raises:
ValueError: If both `prompt` and `system_instruction` are specified.
TypeError: If there is an invalid tool (e.g. function with an input
that did not specify its type).
"""
from google.cloud.aiplatform import initializer
self._project = initializer.global_config.project
self._location = initializer.global_config.location
self._tools = []
if tools:
# We validate tools at initialization for actionable feedback before
# they are deployed.
_validate_tools(tools)
self._tools = tools
if prompt and system_instruction:
raise ValueError(
"Only one of `prompt` or `system_instruction` should be specified. "
"Consider incorporating the system instruction into the prompt "
"rather than passing it separately as an argument."
)
self._model_name = model
self._system_instruction = system_instruction
self._prompt = prompt
self._output_parser = output_parser
self._chat_history = chat_history
self._model_kwargs = model_kwargs
self._model_tool_kwargs = model_tool_kwargs
self._agent_executor_kwargs = agent_executor_kwargs
self._runnable_kwargs = runnable_kwargs
self._model = None
self._model_builder = model_builder
self._runnable = None
self._runnable_builder = runnable_builder
self._instrumentor = None
self._enable_tracing = enable_tracing
def set_up(self):
"""Sets up the agent for execution of queries at runtime.
It initializes the model, binds the model with tools, and connects it
with the prompt template and output parser.
This method should not be called for an object that is being passed to
the ReasoningEngine service for deployment, as it initializes clients
that cannot be serialized.
"""
if self._enable_tracing:
from vertexai.reasoning_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=self._project,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(self._project),
),
)
span_processor: SpanProcessor = (
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
)
tracer_provider: TracerProvider = (
opentelemetry.trace.get_tracer_provider()
)
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
# Keep the instrumentation up-to-date.
# When creating multiple LangchainAgents,
# we need to keep the instrumentation up-to-date.
# We deliberately override the instrument each time,
# so that if different agents end up using different
# instrumentations, we guarantee that the user is always
# working with the most recent agent's instrumentation.
self._instrumentor = openinference_langchain.LangChainInstrumentor()
if self._instrumentor.is_instrumented_by_opentelemetry:
self._instrumentor.uninstrument()
self._instrumentor.instrument()
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
model_builder = self._model_builder or _default_model_builder
self._model = model_builder(
model_name=self._model_name,
model_kwargs=self._model_kwargs,
project=self._project,
location=self._location,
)
runnable_builder = self._runnable_builder or _default_runnable_builder
self._runnable = runnable_builder(
prompt=self._prompt,
model=self._model,
tools=self._tools,
system_instruction=self._system_instruction,
output_parser=self._output_parser,
chat_history=self._chat_history,
model_tool_kwargs=self._model_tool_kwargs,
agent_executor_kwargs=self._agent_executor_kwargs,
runnable_kwargs=self._runnable_kwargs,
)
def clone(self) -> "LangchainAgent":
"""Returns a clone of the LangchainAgent."""
import copy
return LangchainAgent(
model=self._model_name,
system_instruction=self._system_instruction,
prompt=copy.deepcopy(self._prompt),
tools=copy.deepcopy(self._tools),
output_parser=copy.deepcopy(self._output_parser),
chat_history=copy.deepcopy(self._chat_history),
model_kwargs=copy.deepcopy(self._model_kwargs),
model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
agent_executor_kwargs=copy.deepcopy(self._agent_executor_kwargs),
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
model_builder=self._model_builder,
runnable_builder=self._runnable_builder,
enable_tracing=self._enable_tracing,
)
def query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.invoke()` method of the corresponding AgentExecutor.
Returns:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._runnable:
self.set_up()
return langchain_load_dump.dumpd(
self._runnable.invoke(input=input, config=config, **kwargs)
)
def stream_query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs,
) -> Iterable[Any]:
"""Stream queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.stream()` method of the corresponding AgentExecutor.
Yields:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._runnable:
self.set_up()
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
yield langchain_load_dump.dumpd(chunk)
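A hedged usage sketch of `LangchainAgent`; the tool is a stub and the identifiers are placeholders, with `vertexai.init` assumed to have been called:

```python
from vertexai.preview.reasoning_engines import LangchainAgent

def get_exchange_rate(currency_from: str, currency_to: str) -> str:
    """Returns a stubbed exchange rate between two currencies (illustrative only)."""
    return f"1 {currency_from} = 0.9 {currency_to}"

agent = LangchainAgent(
    model="gemini-1.0-pro",
    system_instruction="You are a helpful currency assistant.",
    tools=[get_exchange_rate],
)
response = agent.query(input="What is the exchange rate from USD to EUR?")
for chunk in agent.stream_query(input="And from EUR to JPY?"):
    print(chunk)
```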

View File

@@ -0,0 +1,658 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
try:
from langchain_core import runnables
from langchain_core import tools as lc_tools
from langchain_core.language_models import base as lc_language_models
BaseTool = lc_tools.BaseTool
BaseLanguageModel = lc_language_models.BaseLanguageModel
RunnableConfig = runnables.RunnableConfig
RunnableSerializable = runnables.RunnableSerializable
except ImportError:
BaseTool = Any
BaseLanguageModel = Any
RunnableConfig = Any
RunnableSerializable = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType
_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
try:
from langgraph_checkpoint.checkpoint import base
BaseCheckpointSaver = base.BaseCheckpointSaver
except ImportError:
try:
from langgraph.checkpoint import base
BaseCheckpointSaver = base.BaseCheckpointSaver
except ImportError:
BaseCheckpointSaver = Any
def _default_model_builder(
model_name: str,
*,
project: str,
location: str,
model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
"""Default callable for building a language model.
Args:
model_name (str):
Required. The name of the model (e.g. "gemini-1.0-pro").
project (str):
Required. The Google Cloud project ID.
location (str):
Required. The Google Cloud location.
model_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
chat_models.ChatVertexAI.
Returns:
BaseLanguageModel: The language model.
"""
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI
model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
def _default_runnable_builder(
model: "BaseLanguageModel",
*,
tools: Optional[Sequence["_ToolLike"]] = None,
checkpointer: Optional[Any] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
"""Default callable for building a runnable.
Args:
model (BaseLanguageModel):
Required. The language model.
tools (Optional[Sequence[_ToolLike]]):
Optional. The tools for the agent to be able to use.
checkpointer (Optional[Checkpointer]):
Optional. The checkpointer for the agent.
model_tool_kwargs (Optional[Mapping[str, Any]]):
Optional. Additional keyword arguments when binding tools to the model.
runnable_kwargs (Optional[Mapping[str, Any]]):
Optional. Additional keyword arguments for the runnable.
Returns:
RunnableSerializable: The runnable.
"""
from langgraph import prebuilt as langgraph_prebuilt
model_tool_kwargs = model_tool_kwargs or {}
runnable_kwargs = runnable_kwargs or {}
if tools:
model = model.bind_tools(tools=tools, **model_tool_kwargs)
else:
tools = []
if checkpointer:
if "checkpointer" in runnable_kwargs:
from google.cloud.aiplatform import base
base.Logger(__name__).warning(
"checkpointer is being specified in both checkpointer_builder "
"and runnable_kwargs. Please specify it in only one of them. "
"Overriding the checkpointer in runnable_kwargs."
)
runnable_kwargs["checkpointer"] = checkpointer
return langgraph_prebuilt.create_react_agent(
model,
tools=tools,
**runnable_kwargs,
)
def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing LangChain tools that are
usable with Gemini function calling.
Args:
callable (Callable): The callable to validate.
Raises:
TypeError: If any parameter is not annotated.
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)
def _validate_tools(tools: Sequence["_ToolLike"]):
"""Validates that the tools are usable for tool calling.
Args:
tools (Sequence[_ToolLike]): The tools to validate.
Raises:
TypeError: If any tool is a callable with untyped parameters.
"""
for tool in tools:
if isinstance(tool, Callable):
_validate_callable_parameters_are_annotated(tool)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple LanggraphAgents in the same environment,
it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
Because we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new LanggraphAgent is created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
active_span_processor (SynchronousMultiSpanProcessor):
The span processor that overrides the tracer provider's current
active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class LanggraphAgent:
"""A LangGraph Agent.
See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/langgraph
for details.
"""
def __init__(
self,
model: str,
*,
tools: Optional[Sequence["_ToolLike"]] = None,
model_kwargs: Optional[Mapping[str, Any]] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
model_builder: Optional[Callable[..., "BaseLanguageModel"]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
runnable_builder: Optional[Callable[..., "RunnableSerializable"]] = None,
checkpointer_kwargs: Optional[Mapping[str, Any]] = None,
checkpointer_builder: Optional[Callable[..., "BaseCheckpointSaver"]] = None,
enable_tracing: bool = False,
):
"""Initializes the LangGraph Agent.
Under-the-hood, assuming .set_up() is called, this will correspond to
```python
model = model_builder(model_name=model, model_kwargs=model_kwargs)
runnable = runnable_builder(
model=model,
tools=tools,
model_tool_kwargs=model_tool_kwargs,
runnable_kwargs=runnable_kwargs,
)
```
When everything is based on their default values, this corresponds to
```python
# model_builder
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model_name=model, **model_kwargs)
# runnable_builder
from langgraph.prebuilt import create_react_agent
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
runnable = create_react_agent(
llm_with_tools,
tools=tools,
**runnable_kwargs,
)
```
By default, no checkpointer is used (i.e. there is no state history). To
enable checkpointing, provide a `checkpointer_builder` function that
returns a checkpointer instance.
**Example using Spanner:**
```python
def checkpointer_builder(instance_id, database_id, project_id, **kwargs):
from langchain_google_spanner import SpannerCheckpointSaver
checkpointer = SpannerCheckpointSaver(instance_id, database_id, project_id)
with checkpointer.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS checkpoints")
cur.execute("DROP TABLE IF EXISTS checkpoint_writes")
checkpointer.setup()
return checkpointer
```
**Example using an in-memory checkpointer:**
```python
def checkpointer_builder(**kwargs):
from langgraph.checkpoint.memory import MemorySaver
return MemorySaver()
```
The `checkpointer_builder` function will be called with the keyword
arguments specified in `checkpointer_kwargs`. Ensure your
`checkpointer_builder` function accepts `**kwargs` so it can handle any
arguments it does not use, as shown in the example below.
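For example, a minimal sketch of wiring the Spanner checkpointer above
into an agent (the model name and resource IDs are placeholders):
```python
agent = LanggraphAgent(
    model="gemini-1.0-pro",
    checkpointer_kwargs={
        "instance_id": "your-instance-id",
        "database_id": "your-database-id",
        "project_id": "your-project-id",
    },
    checkpointer_builder=checkpointer_builder,
)
```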
Args:
model (str):
Required. The name of the model (e.g. "gemini-1.0-pro").
tools (Sequence[langchain_core.tools.BaseTool, Callable]):
Optional. The tools for the agent to be able to use. All input
callables (e.g. function or class method) will be converted
to a langchain.tools.base.StructuredTool. Defaults to None.
model_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
chat_models.ChatVertexAI. An example would be
```
{
# temperature (float): Sampling temperature, it controls the
# degree of randomness in token selection.
"temperature": 0.28,
# max_output_tokens (int): Token limit determines the
# maximum amount of text output from one prompt.
"max_output_tokens": 1000,
# top_p (float): Tokens are selected from most probable to
# least, until the sum of their probabilities equals the
# top_p value.
"top_p": 0.95,
# top_k (int): How the model selects tokens for output, the
# next token is selected from among the top_k most probable
# tokens.
"top_k": 40,
}
```
model_tool_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments when binding tools to the
model using `model.bind_tools()`.
model_builder (Callable[..., "BaseLanguageModel"]):
Optional. Callable that returns a new language model. Defaults
to a callable that returns ChatVertexAI based on `model`,
`model_kwargs` and the parameters in `vertexai.init`.
runnable_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments passed to the runnable
returned by `runnable_builder` (by default,
`langgraph.prebuilt.create_react_agent`).
runnable_builder (Callable[..., "RunnableSerializable"]):
Optional. Callable that returns a new runnable. This can be used
for customizing the orchestration logic of the Agent based on
the model returned by `model_builder` and the rest of the input
arguments.
checkpointer_kwargs (Mapping[str, Any]):
Optional. Keyword arguments passed to `checkpointer_builder`
when constructing the checkpointer.
checkpointer_builder (Callable[..., "BaseCheckpointSaver"]):
Optional. Callable that returns a checkpointer. This can be used
for defining the checkpointer of the Agent. Defaults to None.
enable_tracing (bool):
Optional. Whether to enable tracing in Cloud Trace. Defaults to
False.
Raises:
TypeError: If there is an invalid tool (e.g. function with an input
that did not specify its type).
"""
from google.cloud.aiplatform import initializer
self._project = initializer.global_config.project
self._location = initializer.global_config.location
self._tools = []
if tools:
# We validate tools at initialization for actionable feedback before
# they are deployed.
_validate_tools(tools)
self._tools = tools
self._model_name = model
self._model_kwargs = model_kwargs
self._model_tool_kwargs = model_tool_kwargs
self._runnable_kwargs = runnable_kwargs
self._checkpointer_kwargs = checkpointer_kwargs
self._model = None
self._model_builder = model_builder
self._runnable = None
self._runnable_builder = runnable_builder
self._checkpointer_builder = checkpointer_builder
self._instrumentor = None
self._enable_tracing = enable_tracing
def set_up(self):
"""Sets up the agent for execution of queries at runtime.
It initializes the model, binds the model with tools, and assembles
the runnable (including the checkpointer, if one is configured).
This method should not be called for an object that is being passed to
the ReasoningEngine service for deployment, as it initializes clients
that cannot be serialized.
"""
if self._enable_tracing:
from vertexai.reasoning_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_langchain = _utils._import_openinference_langchain_or_warn()
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_langchain,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=self._project,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(self._project),
),
)
span_processor: SpanProcessor = (
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
)
tracer_provider: TracerProvider = (
opentelemetry.trace.get_tracer_provider()
)
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
base.Logger(__name__).warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
# Keep the instrumentation up-to-date when creating multiple
# LanggraphAgents. We deliberately re-instrument each time, so that
# if different agents end up using different instrumentations, the
# user is always working with the most recent agent's instrumentation.
self._instrumentor = openinference_langchain.LangChainInstrumentor()
if self._instrumentor.is_instrumented_by_opentelemetry:
self._instrumentor.uninstrument()
self._instrumentor.instrument()
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
model_builder = self._model_builder or _default_model_builder
self._model = model_builder(
model_name=self._model_name,
model_kwargs=self._model_kwargs,
project=self._project,
location=self._location,
)
self._checkpointer = None
if self._checkpointer_builder:
checkpointer_kwargs = self._checkpointer_kwargs or {}
self._checkpointer = self._checkpointer_builder(
**checkpointer_kwargs,
)
runnable_builder = self._runnable_builder or _default_runnable_builder
self._runnable = runnable_builder(
model=self._model,
tools=self._tools,
checkpointer=self._checkpointer,
model_tool_kwargs=self._model_tool_kwargs,
runnable_kwargs=self._runnable_kwargs,
)
def clone(self) -> "LanggraphAgent":
"""Returns a clone of the LanggraphAgent."""
import copy
return LanggraphAgent(
model=self._model_name,
tools=copy.deepcopy(self._tools),
model_kwargs=copy.deepcopy(self._model_kwargs),
model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
checkpointer_kwargs=copy.deepcopy(self._checkpointer_kwargs),
model_builder=self._model_builder,
runnable_builder=self._runnable_builder,
checkpointer_builder=self._checkpointer_builder,
enable_tracing=self._enable_tracing,
)
def query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.invoke()` method of the corresponding AgentExecutor.
Returns:
The output of querying the Agent with the given input and config.
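Example (a minimal sketch; the model name, thread ID, and question are
placeholders, and the `config` is only needed when a checkpointer is
configured):
```python
agent = LanggraphAgent(model="gemini-1.0-pro")
response = agent.query(
    input="What is the capital of Sweden?",
    config={"configurable": {"thread_id": "demo-thread"}},
)
```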
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._runnable:
self.set_up()
return langchain_load_dump.dumpd(
self._runnable.invoke(input=input, config=config, **kwargs)
)
def stream_query(
self,
*,
input: Union[str, Mapping[str, Any]],
config: Optional["RunnableConfig"] = None,
**kwargs,
) -> Iterable[Any]:
"""Stream queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
config (langchain_core.runnables.RunnableConfig):
Optional. The config (if any) to be used for invoking the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.invoke()` method of the corresponding AgentExecutor.
Yields:
The output of querying the Agent with the given input and config.
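Example (a minimal sketch; the model name and prompt are placeholders):
```python
agent = LanggraphAgent(model="gemini-1.0-pro")
for chunk in agent.stream_query(input="Summarize today's top headline."):
    print(chunk)
```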
"""
from langchain.load import dump as langchain_load_dump
if isinstance(input, str):
input = {"input": input}
if not self._runnable:
self.set_up()
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
yield langchain_load_dump.dumpd(chunk)
def get_state_history(
self,
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Iterable[Any]:
"""Gets the state history of the Agent.
Args:
config (Optional[RunnableConfig]):
Optional. The config for invoking the Agent.
**kwargs:
Optional. Additional keyword arguments for the `.invoke()` method.
Yields:
Dict[str, Any]: The state history of the Agent.
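Example (a sketch; assumes the agent was built with a checkpointer via
`checkpointer_builder` and that "demo-thread" is a thread ID used in
earlier queries):
```python
config = {"configurable": {"thread_id": "demo-thread"}}
for state in agent.get_state_history(config=config):
    print(state)
```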
"""
if not self._runnable:
self.set_up()
for state_snapshot in self._runnable.get_state_history(config=config, **kwargs):
yield state_snapshot._asdict()
def get_state(
self,
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Gets the current state of the Agent.
Args:
config (Optional[RunnableConfig]):
Optional. The config for invoking the Agent.
**kwargs:
Optional. Additional keyword arguments for the `.invoke()` method.
Returns:
Dict[str, Any]: The current state of the Agent.
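Example (a sketch; same assumptions as `get_state_history` above):
```python
config = {"configurable": {"thread_id": "demo-thread"}}
current_state = agent.get_state(config=config)
```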
"""
if not self._runnable:
self.set_up()
return self._runnable.get_state(config=config, **kwargs)._asdict()
def update_state(
self,
config: Optional["RunnableConfig"] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Updates the state of the Agent.
Args:
config (Optional[RunnableConfig]):
Optional. The config for invoking the Agent.
**kwargs:
Optional. Additional keyword arguments for the `.invoke()` method.
Returns:
Dict[str, Any]: The updated state of the Agent.
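Example (a sketch; assumes a checkpointer is configured, and the
`values` payload below is a hypothetical state update):
```python
agent.update_state(
    config={"configurable": {"thread_id": "demo-thread"}},
    values={"messages": []},
)
```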
"""
if not self._runnable:
self.set_up()
return self._runnable.update_state(config=config, **kwargs)
def register_operations(self) -> Mapping[str, Sequence[str]]:
"""Registers the operations of the Agent.
This mapping defines how different operation modes (e.g., "", "stream")
are implemented by specific methods of the Agent. The "default" mode,
represented by the empty string `""`, covers the `query`, `get_state`,
and `update_state` APIs, while the "stream" mode covers the
`stream_query` and `get_state_history` APIs.
Returns:
Mapping[str, Sequence[str]]: A mapping of operation modes to a list
of method names that implement those operation modes.
"""
return {
"": ["query", "get_state", "update_state"],
"stream": ["stream_query", "get_state_history"],
}

View File

@@ -0,0 +1,553 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Union,
)
if TYPE_CHECKING:
try:
from llama_index.core.base.query_pipeline import query
from llama_index.core.llms import function_calling
from llama_index.core import query_pipeline
FunctionCallingLLM = function_calling.FunctionCallingLLM
QueryComponent = query.QUERY_COMPONENT_TYPE
QueryPipeline = query_pipeline.QueryPipeline
except ImportError:
FunctionCallingLLM = Any
QueryComponent = Any
QueryPipeline = Any
try:
from opentelemetry.sdk import trace
TracerProvider = trace.TracerProvider
SpanProcessor = trace.SpanProcessor
SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
except ImportError:
TracerProvider = Any
SpanProcessor = Any
SynchronousMultiSpanProcessor = Any
def _default_model_builder(
model_name: str,
*,
project: str,
location: str,
model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "FunctionCallingLLM":
"""Creates a default model builder for LlamaIndex."""
import vertexai
from google.cloud.aiplatform import initializer
from llama_index.llms import google_genai
model_kwargs = model_kwargs or {}
model = google_genai.GoogleGenAI(
model=model_name,
vertexai_config={"project": project, "location": location},
**model_kwargs,
)
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=current_project, location=current_location)
return model
def _default_runnable_builder(
model: "FunctionCallingLLM",
*,
system_instruction: Optional[str] = None,
prompt: Optional["QueryComponent"] = None,
retriever: Optional["QueryComponent"] = None,
response_synthesizer: Optional["QueryComponent"] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "QueryPipeline":
"""Creates a default runnable builder for LlamaIndex."""
try:
from llama_index.core.query_pipeline import QueryPipeline
except ImportError:
raise ImportError(
"Please call 'pip install google-cloud-aiplatform[llama_index]'."
)
prompt = prompt or _default_prompt(
system_instruction=system_instruction,
)
pipeline = QueryPipeline(**runnable_kwargs)
pipeline_modules = {
"prompt": prompt,
"model": model,
}
if retriever:
pipeline_modules["retriever"] = retriever
if response_synthesizer:
pipeline_modules["response_synthesizer"] = response_synthesizer
pipeline.add_modules(pipeline_modules)
pipeline.add_link("prompt", "model")
if "retriever" in pipeline_modules:
pipeline.add_link("model", "retriever")
if "response_synthesizer" in pipeline_modules:
pipeline.add_link("model", "response_synthesizer", dest_key="query_str")
if "retriever" in pipeline_modules:
pipeline.add_link("retriever", "response_synthesizer", dest_key="nodes")
return pipeline
def _default_prompt(
system_instruction: Optional[str] = None,
) -> "QueryComponent":
"""Creates a default prompt template for LlamaIndex.
Handles both system instruction and user input.
Args:
system_instruction (str, optional): The system instruction to use.
Returns:
QueryComponent: The LlamaIndex QueryComponent.
"""
try:
from llama_index.core import prompts
from llama_index.core.base.llms import types
except ImportError:
raise ImportError(
"Please call 'pip install google-cloud-aiplatform[llama_index]'."
)
# Define a prompt template
message_templates = []
if system_instruction:
message_templates.append(
types.ChatMessage(role=types.MessageRole.SYSTEM, content=system_instruction)
)
# Add user input message
message_templates.append(
types.ChatMessage(role=types.MessageRole.USER, content="{input}")
)
# Create the prompt template
return prompts.ChatPromptTemplate(message_templates=message_templates)
def _override_active_span_processor(
tracer_provider: "TracerProvider",
active_span_processor: "SynchronousMultiSpanProcessor",
):
"""Overrides the active span processor.
When working with multiple LlamaIndexQueryPipelineAgents in the same
environment, it's crucial to manage trace exports carefully.
Each agent needs its own span processor tied to a unique project ID.
Because we add a new span processor for each agent, this can lead to
unexpected behavior.
For instance, with two agents linked to different projects, traces from the
second agent might be sent to both projects.
To prevent this and guarantee traces go to the correct project, we overwrite
the active span processor whenever a new LlamaIndexQueryPipelineAgent is
created.
Args:
tracer_provider (TracerProvider):
The tracer provider to use for the project.
active_span_processor (SynchronousMultiSpanProcessor):
The span processor that overrides the tracer provider's current
active span processor.
"""
if tracer_provider._active_span_processor:
tracer_provider._active_span_processor.shutdown()
tracer_provider._active_span_processor = active_span_processor
class LlamaIndexQueryPipelineAgent:
"""A LlamaIndex Query Pipeline Agent.
This agent uses a LlamaIndex query pipeline, including prompt, model,
retrieval, and summarization steps. More details can be found in
https://docs.llamaindex.ai/en/stable/module_guides/querying/pipeline/.
"""
def __init__(
self,
model: str,
*,
system_instruction: Optional[str] = None,
prompt: Optional["QueryComponent"] = None,
model_kwargs: Optional[Mapping[str, Any]] = None,
model_builder: Optional[Callable[..., "FunctionCallingLLM"]] = None,
retriever_kwargs: Optional[Mapping[str, Any]] = None,
retriever_builder: Optional[Callable[..., "QueryComponent"]] = None,
response_synthesizer_kwargs: Optional[Mapping[str, Any]] = None,
response_synthesizer_builder: Optional[Callable[..., "QueryComponent"]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
runnable_builder: Optional[Callable[..., "QueryPipeline"]] = None,
enable_tracing: bool = False,
):
"""Initializes the LlamaIndexQueryPipelineAgent.
Under-the-hood, assuming .set_up() is called, this will correspond to
```python
# model_builder
model = model_builder(model_name, project, location, model_kwargs)
# runnable_builder
runnable = runnable_builder(
prompt=prompt,
model=model,
retriever=retriever_builder(model, retriever_kwargs),
response_synthesizer=response_synthesizer_builder(
model, response_synthesizer_kwargs
),
runnable_kwargs=runnable_kwargs,
)
```
When everything is based on their default values, this corresponds to a
query pipeline `Prompt - Model`:
```python
# Default Model Builder
model = google_genai.GoogleGenAI(
model=model_name,
vertexai_config={
"project": initializer.global_config.project,
"location": initializer.global_config.location,
},
)
# Default Prompt Builder
prompt = prompts.ChatPromptTemplate(
message_templates=[
types.ChatMessage(
role=types.MessageRole.USER,
content="{input}",
),
],
)
# Default Runnable Builder
runnable = QueryPipeline(
modules = {
"prompt": prompt,
"model": model,
},
)
pipeline.add_link("prompt", "model")
```
When `system_instruction` is specified, the prompt will be updated to
include the system instruction.
```python
# Updated Prompt Builder
prompt = prompts.ChatPromptTemplate(
message_templates=[
types.ChatMessage(
role=types.MessageRole.SYSTEM,
content=system_instruction,
),
types.ChatMessage(
role=types.MessageRole.USER,
content="{input}",
),
],
)
```
When all inputs are specified, this corresponds to a query pipeline
`Prompt - Model - Retriever - Summarizer`:
```python
runnable = QueryPipeline(
modules = {
"prompt": prompt,
"model": model,
"retriever": retriever_builder(retriever_kwargs),
"response_synthesizer": response_synthesizer_builder(
response_synthesizer_kwargs
),
},
)
pipeline.add_link("prompt", "model")
pipeline.add_link("model", "retriever")
pipeline.add_link("model", "response_synthesizer", dest_key="query_str")
pipeline.add_link("retriever", "response_synthesizer", dest_key="nodes")
```
Args:
model (str):
The name of the model (e.g. "gemini-1.0-pro").
system_instruction (str):
Optional. The system instruction to use for the agent.
prompt (llama_index.core.base.query_pipeline.query.QUERY_COMPONENT_TYPE):
Optional. The prompt template for the model.
model_kwargs (Mapping[str, Any]):
Optional. Keyword arguments for the constructor of
google_genai.GoogleGenAI. An example of model_kwargs is:
```python
{
# api_key (string): The API key for the GoogleGenAI model.
# The API can also be fetched from the GOOGLE_API_KEY
# environment variable. If `vertexai_config` is provided,
# the API key is ignored.
"api_key": "your_api_key",
# temperature (float): Sampling temperature, it controls the
# degree of randomness in token selection. If not provided,
# the default temperature is 0.1.
"temperature": 0.1,
# context_window (int): The context window of the model.
# If not provided, the default context window is 200000.
"context_window": 200000,
# max_tokens (int): Token limit determines the maximum
# amount of text output from one prompt. If not provided,
# the default max_tokens is 256.
"max_tokens": 256,
# is_function_calling_model (bool): Whether the model is a
# function calling model. If not provided, the default
# is_function_calling_model is True.
"is_function_calling_model": True,
}
```
model_builder (Callable):
Optional. Callable that returns a language model.
retriever_kwargs (Mapping[str, Any]):
Optional. Keyword arguments for the retriever constructor.
retriever_builder (Callable):
Optional. Callable that returns a retriever object.
response_synthesizer_kwargs (Mapping[str, Any]):
Optional. Keyword arguments for the response synthesizer constructor.
response_synthesizer_builder (Callable):
Optional. Callable that returns a response_synthesizer object.
runnable_kwargs (Mapping[str, Any]):
Optional. Keyword arguments for the runnable constructor.
runnable_builder (Callable):
Optional. Callable that returns a runnable (query pipeline).
enable_tracing (bool):
Optional. Whether to enable tracing. Defaults to False.
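For example, a minimal sketch of constructing the agent (the model name
and system instruction are placeholders):
```python
agent = LlamaIndexQueryPipelineAgent(
    model="gemini-1.0-pro",
    system_instruction="You are a concise research assistant.",
)
```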
"""
from google.cloud.aiplatform import initializer
self._project = initializer.global_config.project
self._location = initializer.global_config.location
self._model_name = model
self._system_instruction = system_instruction
self._prompt = prompt
self._model = None
self._model_kwargs = model_kwargs or {}
self._model_builder = model_builder
self._retriever = None
self._retriever_kwargs = retriever_kwargs or {}
self._retriever_builder = retriever_builder
self._response_synthesizer = None
self._response_synthesizer_kwargs = response_synthesizer_kwargs or {}
self._response_synthesizer_builder = response_synthesizer_builder
self._runnable = None
self._runnable_kwargs = runnable_kwargs or {}
self._runnable_builder = runnable_builder
self._instrumentor = None
self._enable_tracing = enable_tracing
def set_up(self):
"""Sets up the agent for execution of queries at runtime.
It initializes the model, connects it with the prompt template,
retriever and response_synthesizer.
This method should not be called for an object that is being passed to
the ReasoningEngine service for deployment, as it initializes clients
that cannot be serialized.
"""
if self._enable_tracing:
from vertexai.reasoning_engines import _utils
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
openinference_llama_index = (
_utils._import_openinference_llama_index_or_warn()
)
opentelemetry = _utils._import_opentelemetry_or_warn()
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
if all(
(
cloud_trace_exporter,
cloud_trace_v2,
openinference_llama_index,
opentelemetry,
opentelemetry_sdk_trace,
)
):
import google.auth
credentials, _ = google.auth.default()
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
project_id=self._project,
client=cloud_trace_v2.TraceServiceClient(
credentials=credentials.with_quota_project(self._project),
),
)
span_processor: SpanProcessor = (
opentelemetry_sdk_trace.export.SimpleSpanProcessor(
span_exporter=span_exporter,
)
)
tracer_provider: TracerProvider = (
opentelemetry.trace.get_tracer_provider()
)
# Get the appropriate tracer provider:
# 1. If _TRACER_PROVIDER is already set, use that.
# 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
# variable is set, use that.
# 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
# If none of the above is set, we log a warning, and
# create a tracer provider.
if not tracer_provider:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"No tracer provider. By default, "
"we should get one of the following providers: "
"OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
"or _PROXY_TRACER_PROVIDER."
)
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids AttributeError:
# 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
# attribute 'add_span_processor'.
if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
tracer_provider = opentelemetry_sdk_trace.TracerProvider()
opentelemetry.trace.set_tracer_provider(tracer_provider)
# Avoids OpenTelemetry client already exists error.
_override_active_span_processor(
tracer_provider,
opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
)
tracer_provider.add_span_processor(span_processor)
# Keep the instrumentation up-to-date when creating multiple
# LlamaIndexQueryPipelineAgents. We deliberately re-instrument each
# time, so that if different agents end up using different
# instrumentations, the user is always working with the most recent
# agent's instrumentation.
self._instrumentor = openinference_llama_index.LlamaIndexInstrumentor()
if self._instrumentor.is_instrumented_by_opentelemetry:
self._instrumentor.uninstrument()
self._instrumentor.instrument()
else:
from google.cloud.aiplatform import base
_LOGGER = base.Logger(__name__)
_LOGGER.warning(
"enable_tracing=True but proceeding with tracing disabled "
"because not all packages for tracing have been installed"
)
model_builder = self._model_builder or _default_model_builder
self._model = model_builder(
model_name=self._model_name,
model_kwargs=self._model_kwargs,
project=self._project,
location=self._location,
)
if self._retriever_builder:
self._retriever = self._retriever_builder(
model=self._model,
retriever_kwargs=self._retriever_kwargs,
)
if self._response_synthesizer_builder:
self._response_synthesizer = self._response_synthesizer_builder(
model=self._model,
response_synthesizer_kwargs=self._response_synthesizer_kwargs,
)
runnable_builder = self._runnable_builder or _default_runnable_builder
self._runnable = runnable_builder(
prompt=self._prompt,
model=self._model,
system_instruction=self._system_instruction,
retriever=self._retriever,
response_synthesizer=self._response_synthesizer,
runnable_kwargs=self._runnable_kwargs,
)
def clone(self) -> "LlamaIndexQueryPipelineAgent":
"""Returns a clone of the LlamaIndexQueryPipelineAgent."""
import copy
return LlamaIndexQueryPipelineAgent(
model=self._model_name,
system_instruction=self._system_instruction,
prompt=copy.deepcopy(self._prompt),
model_kwargs=copy.deepcopy(self._model_kwargs),
model_builder=self._model_builder,
retriever_kwargs=copy.deepcopy(self._retriever_kwargs),
retriever_builder=self._retriever_builder,
response_synthesizer_kwargs=copy.deepcopy(
self._response_synthesizer_kwargs
),
response_synthesizer_builder=self._response_synthesizer_builder,
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
runnable_builder=self._runnable_builder,
enable_tracing=self._enable_tracing,
)
def query(
self,
input: Union[str, Mapping[str, Any]],
**kwargs: Any,
) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]:
"""Queries the Agent with the given input and config.
Args:
input (Union[str, Mapping[str, Any]]):
Required. The input to be passed to the Agent.
**kwargs:
Optional. Any additional keyword arguments to be passed to the
`.invoke()` method of the corresponding AgentExecutor.
Returns:
The output of querying the Agent with the given input and config.
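Example (a minimal sketch; the model name and question are placeholders):
```python
agent = LlamaIndexQueryPipelineAgent(model="gemini-1.0-pro")
response = agent.query(input="What is 2 + 2?")
```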
"""
from vertexai.reasoning_engines import _utils
if isinstance(input, str):
input = {"input": input}
if not self._runnable:
self.set_up()
if kwargs.get("batch"):
nest_asyncio = _utils._import_nest_asyncio_or_warn()
nest_asyncio.apply()
return _utils.to_json_serializable_llama_index_object(
self._runnable.run(**input, **kwargs)
)