structure saas with tools
@@ -0,0 +1,47 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Classes for working with reasoning engines."""

# We just want to re-export certain classes
# pylint: disable=g-multiple-import,g-importing-member
from vertexai.reasoning_engines._reasoning_engines import (
    Queryable,
    ReasoningEngine,
)
from vertexai.preview.reasoning_engines.templates.adk import (
    AdkApp,
)
from vertexai.preview.reasoning_engines.templates.ag2 import (
    AG2Agent,
)
from vertexai.preview.reasoning_engines.templates.langchain import (
    LangchainAgent,
)
from vertexai.preview.reasoning_engines.templates.langgraph import (
    LanggraphAgent,
)
from vertexai.preview.reasoning_engines.templates.llama_index import (
    LlamaIndexQueryPipelineAgent,
)

__all__ = (
    "AdkApp",
    "AG2Agent",
    "LangchainAgent",
    "LanggraphAgent",
    "LlamaIndexQueryPipelineAgent",
    "Queryable",
    "ReasoningEngine",
)
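Reviewer note: a minimal usage sketch of the preview surface re-exported above. This is a hedged sketch, not part of the diff — it assumes `vertexai.init(...)` works against a real project (the project/location strings below are placeholders) and that the optional template dependencies (e.g. `langchain-google-vertexai`) are installed.

```python
# Sketch only: "my-project" / "us-central1" are placeholders.
import vertexai
from vertexai.preview import reasoning_engines

vertexai.init(project="my-project", location="us-central1")

# Each template wraps a framework-specific runnable behind the same
# queryable surface that ReasoningEngine can deploy.
agent = reasoning_engines.LangchainAgent(model="gemini-1.0-pro")
print(agent.query(input="What is Agent Engine?"))
```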
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,651 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
if TYPE_CHECKING:
    try:
        from google.adk.events.event import Event

        Event = Event
    except (ImportError, AttributeError):
        Event = Any

    try:
        from google.adk.agents import BaseAgent

        BaseAgent = BaseAgent
    except (ImportError, AttributeError):
        BaseAgent = Any

    try:
        from google.adk.sessions import BaseSessionService

        BaseSessionService = BaseSessionService
    except (ImportError, AttributeError):
        BaseSessionService = Any

    try:
        from google.adk.artifacts import BaseArtifactService

        BaseArtifactService = BaseArtifactService
    except (ImportError, AttributeError):
        BaseArtifactService = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any


_DEFAULT_APP_NAME = "default-app-name"
_DEFAULT_USER_ID = "default-user-id"

class _ArtifactVersion:
    def __init__(self, **kwargs):
        self.version: Optional[str] = kwargs.get("version")
        self.data = kwargs.get("data")

    def dump(self) -> Dict[str, Any]:
        result = {}
        if self.version:
            result["version"] = self.version
        if self.data:
            result["data"] = self.data
        return result


class _Artifact:
    def __init__(self, **kwargs):
        self.file_name: Optional[str] = kwargs.get("file_name")
        self.versions: List[_ArtifactVersion] = kwargs.get("versions")

    def dump(self) -> Dict[str, Any]:
        result = {}
        if self.file_name:
            result["file_name"] = self.file_name
        if self.versions:
            result["versions"] = [version.dump() for version in self.versions]
        return result

class _Authorization:
    def __init__(self, **kwargs):
        self.access_token: Optional[str] = kwargs.get("access_token") or kwargs.get(
            "accessToken"
        )


class _StreamRunRequest:
    """Request object for the `streaming_agent_run_with_events` method."""

    def __init__(self, **kwargs):
        from google.adk.events.event import Event
        from google.genai import types

        self.message: Optional[types.Content] = kwargs.get("message")
        # The new message to be processed by the agent.

        self.events: Optional[List[Event]] = kwargs.get("events")
        # List of preceding events that happened in the same session.

        self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
        # List of artifacts belonging to the session.

        self.authorizations: Dict[str, _Authorization] = kwargs.get(
            "authorizations", {}
        )
        # The authorizations of the user, keyed by authorization ID.

        self.user_id: Optional[str] = kwargs.get("user_id", _DEFAULT_USER_ID)
        # The user ID.

class _StreamingRunResponse:
    """Response object for the `streaming_agent_run_with_events` method.

    It contains the generated events together with the associated artifacts.
    """

    def __init__(self, **kwargs):
        self.events: Optional[List["Event"]] = kwargs.get("events")
        # List of generated events.
        self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
        # List of artifacts belonging to the session.

    def dump(self) -> Dict[str, Any]:
        result = {}
        if self.events:
            result["events"] = []
            for event in self.events:
                event_dict = event.model_dump(exclude_none=True)
                event_dict["invocation_id"] = event_dict.get("invocation_id", "")
                result["events"].append(event_dict)
        if self.artifacts:
            result["artifacts"] = [artifact.dump() for artifact in self.artifacts]
        return result

def _default_instrumentor_builder(project_id: str):
    from vertexai.agent_engines import _utils

    cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
    cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
    opentelemetry = _utils._import_opentelemetry_or_warn()
    opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
    if all(
        (
            cloud_trace_exporter,
            cloud_trace_v2,
            opentelemetry,
            opentelemetry_sdk_trace,
        )
    ):
        import google.auth

        credentials, _ = google.auth.default()
        span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
            project_id=project_id,
            client=cloud_trace_v2.TraceServiceClient(
                credentials=credentials.with_quota_project(project_id),
            ),
        )
        span_processor = opentelemetry_sdk_trace.export.BatchSpanProcessor(
            span_exporter=span_exporter,
        )
        tracer_provider = opentelemetry.trace.get_tracer_provider()
        # Get the appropriate tracer provider:
        # 1. If _TRACER_PROVIDER is already set, use that.
        # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
        #    variable is set, use that.
        # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
        # If none of the above is set, we log a warning and create a
        # tracer provider.
        if not tracer_provider:
            from google.cloud.aiplatform import base

            _LOGGER = base.Logger(__name__)
            _LOGGER.warning(
                "No tracer provider. By default, "
                "we should get one of the following providers: "
                "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                "or _PROXY_TRACER_PROVIDER."
            )
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids AttributeError:
        # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
        # attribute 'add_span_processor'.
        if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
            tracer_provider = opentelemetry_sdk_trace.TracerProvider()
            opentelemetry.trace.set_tracer_provider(tracer_provider)
        # Avoids the "OpenTelemetry client already exists" error.
        _override_active_span_processor(
            tracer_provider,
            opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
        )
        tracer_provider.add_span_processor(span_processor)
        return None
    else:
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            "enable_tracing=True but proceeding with tracing disabled "
            "because not all packages for tracing (i.e. `google-cloud-trace`, "
            "`opentelemetry-sdk`, `opentelemetry-exporter-gcp-trace`) "
            "have been installed"
        )
        return None

def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple agents in the same environment,
    it's crucial to manage trace exports carefully.
    Each agent needs its own span processor tied to a unique project ID.
    While we add a new span processor for each agent, this can lead to
    unexpected behavior.
    For instance, with two agents linked to different projects, traces from the
    second agent might be sent to both projects.
    To prevent this and guarantee traces go to the correct project, we overwrite
    the active span processor whenever a new agent is created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that overrides the tracer provider's
            existing active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor

class AdkApp:
    def __init__(
        self,
        *,
        agent: "BaseAgent",
        enable_tracing: bool = False,
        session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None,
        artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None,
        env_vars: Optional[Dict[str, str]] = None,
    ):
        """An ADK Application."""
        from google.cloud.aiplatform import initializer

        self._tmpl_attrs: Dict[str, Any] = {
            "project": initializer.global_config.project,
            "location": initializer.global_config.location,
            "agent": agent,
            "enable_tracing": enable_tracing,
            "session_service_builder": session_service_builder,
            "artifact_service_builder": artifact_service_builder,
            "app_name": _DEFAULT_APP_NAME,
            "env_vars": env_vars or {},
        }

    def _init_session(
        self,
        session_service: "BaseSessionService",
        artifact_service: "BaseArtifactService",
        request: _StreamRunRequest,
    ):
        """Initializes a temporary session and returns it."""
        from google.adk.events.event import Event
        import random

        session_state = None
        if request.authorizations:
            session_state = {}
            for auth_id, auth in request.authorizations.items():
                auth = _Authorization(**auth)
                session_state[f"temp:{auth_id}"] = auth.access_token

        session_id = f"temp_session_{random.randbytes(8).hex()}"
        session = session_service.create_session(
            app_name=self._tmpl_attrs.get("app_name"),
            user_id=request.user_id,
            session_id=session_id,
            state=session_state,
        )
        if not session:
            raise RuntimeError("Session creation failed.")
        if request.events:
            for event in request.events:
                session_service.append_event(session, Event(**event))
        if request.artifacts:
            for artifact in request.artifacts:
                artifact = _Artifact(**artifact)
                for version_data in sorted(
                    artifact.versions, key=lambda x: x["version"]
                ):
                    version_data = _ArtifactVersion(**version_data)
                    saved_version = artifact_service.save_artifact(
                        app_name=self._tmpl_attrs.get("app_name"),
                        user_id=request.user_id,
                        session_id=session_id,
                        filename=artifact.file_name,
                        artifact=version_data.data,
                    )
                    if saved_version != version_data.version:
                        from google.cloud.aiplatform import base

                        _LOGGER = base.Logger(__name__)
                        _LOGGER.debug(
                            "Artifact '%s' saved at version %s instead of %s",
                            artifact.file_name,
                            saved_version,
                            version_data.version,
                        )
        return session

    def _convert_response_events(
        self,
        user_id: str,
        session_id: str,
        events: List["Event"],
        artifact_service: Optional["BaseArtifactService"],
    ) -> Dict[str, Any]:
        """Converts the events to the streaming run response object and returns its dump."""
        import collections

        result = _StreamingRunResponse(events=events, artifacts=[])

        # Save the generated artifacts into the result object.
        artifact_versions = collections.defaultdict(list)
        for event in events:
            if event.actions and event.actions.artifact_delta:
                for key, version in event.actions.artifact_delta.items():
                    artifact_versions[key].append(version)

        for key, versions in artifact_versions.items():
            result.artifacts.append(
                _Artifact(
                    file_name=key,
                    versions=[
                        _ArtifactVersion(
                            version=version,
                            data=artifact_service.load_artifact(
                                app_name=self._tmpl_attrs.get("app_name"),
                                user_id=user_id,
                                session_id=session_id,
                                filename=key,
                                version=version,
                            ),
                        )
                        for version in versions
                    ],
                )
            )

        return result.dump()

    def clone(self):
        """Returns a clone of the ADK application."""
        import copy

        return AdkApp(
            agent=copy.deepcopy(self._tmpl_attrs.get("agent")),
            enable_tracing=self._tmpl_attrs.get("enable_tracing"),
            session_service_builder=self._tmpl_attrs.get("session_service_builder"),
            artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"),
            env_vars=self._tmpl_attrs.get("env_vars"),
        )

    def set_up(self):
        """Sets up the ADK application."""
        import os
        from google.adk.runners import Runner
        from google.adk.sessions.in_memory_session_service import InMemorySessionService
        from google.adk.artifacts.in_memory_artifact_service import (
            InMemoryArtifactService,
        )

        os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
        project = self._tmpl_attrs.get("project")
        os.environ["GOOGLE_CLOUD_PROJECT"] = project
        location = self._tmpl_attrs.get("location")
        os.environ["GOOGLE_CLOUD_LOCATION"] = location
        if self._tmpl_attrs.get("enable_tracing"):
            self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder(
                project_id=project
            )
        for key, value in self._tmpl_attrs.get("env_vars").items():
            os.environ[key] = value

        artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder")
        if artifact_service_builder:
            self._tmpl_attrs["artifact_service"] = artifact_service_builder()
        else:
            self._tmpl_attrs["artifact_service"] = InMemoryArtifactService()

        session_service_builder = self._tmpl_attrs.get("session_service_builder")
        if session_service_builder:
            self._tmpl_attrs["session_service"] = session_service_builder()
        elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ:
            from google.adk.sessions.vertex_ai_session_service import (
                VertexAiSessionService,
            )

            self._tmpl_attrs["session_service"] = VertexAiSessionService(
                project=project,
                location=location,
            )
            self._tmpl_attrs["app_name"] = os.environ.get(
                "GOOGLE_CLOUD_AGENT_ENGINE_ID",
                self._tmpl_attrs.get("app_name"),
            )
        else:
            self._tmpl_attrs["session_service"] = InMemorySessionService()

        self._tmpl_attrs["runner"] = Runner(
            agent=self._tmpl_attrs.get("agent"),
            session_service=self._tmpl_attrs.get("session_service"),
            artifact_service=self._tmpl_attrs.get("artifact_service"),
            app_name=self._tmpl_attrs.get("app_name"),
        )
        self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService()
        self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService()
        self._tmpl_attrs["in_memory_runner"] = Runner(
            agent=self._tmpl_attrs.get("agent"),
            session_service=self._tmpl_attrs.get("in_memory_session_service"),
            artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
            app_name=self._tmpl_attrs.get("app_name"),
        )

    def stream_query(
        self,
        *,
        message: str,
        user_id: str,
        session_id: Optional[str] = None,
        **kwargs,
    ):
        """Streams responses from the ADK application in response to a message.

        Args:
            message (str):
                Required. The message to stream responses for.
            user_id (str):
                Required. The ID of the user.
            session_id (str):
                Optional. The ID of the session. If not provided, a new
                session will be created for the user.
            **kwargs (dict[str, Any]):
                Optional. Additional keyword arguments to pass to the
                runner.

        Yields:
            The output of querying the ADK application.
        """
        from google.genai import types

        content = types.Content(role="user", parts=[types.Part(text=message)])
        if not self._tmpl_attrs.get("runner"):
            self.set_up()
        if not session_id:
            session = self.create_session(user_id=user_id)
            session_id = session.id
        for event in self._tmpl_attrs.get("runner").run(
            user_id=user_id, session_id=session_id, new_message=content, **kwargs
        ):
            yield event.model_dump(exclude_none=True)

    def streaming_agent_run_with_events(self, request_json: str):
        """Runs the agent in a fresh in-memory session built from the request.

        The JSON-encoded request carries the new message plus any preceding
        events, artifacts, and authorizations; a response is yielded per
        generated event, and the temporary session is deleted afterwards.
        """
        import json
        from google.genai import types

        request = _StreamRunRequest(**json.loads(request_json))
        if not self._tmpl_attrs.get("in_memory_runner"):
            self.set_up()
        if not self._tmpl_attrs.get("artifact_service"):
            self.set_up()
        if not self._tmpl_attrs.get("in_memory_artifact_service"):
            self.set_up()
        if not self._tmpl_attrs.get("in_memory_session_service"):
            self.set_up()
        # Prepare the in-memory session.
        session = self._init_session(
            session_service=self._tmpl_attrs.get("in_memory_session_service"),
            artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
            request=request,
        )
        if not session:
            raise RuntimeError("Session initialization failed.")
        # Run the agent.
        for event in self._tmpl_attrs.get("in_memory_runner").run(
            user_id=request.user_id,
            session_id=session.id,
            new_message=types.Content(**request.message),
        ):
            yield self._convert_response_events(
                user_id=request.user_id,
                session_id=session.id,
                events=[event],
                artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
            )
        self._tmpl_attrs.get("in_memory_session_service").delete_session(
            app_name=self._tmpl_attrs.get("app_name"),
            user_id=request.user_id,
            session_id=session.id,
        )

    def get_session(
        self,
        *,
        user_id: str,
        session_id: str,
        **kwargs,
    ):
        """Gets a session for the given user.

        Args:
            user_id (str):
                Required. The ID of the user.
            session_id (str):
                Required. The ID of the session.
            **kwargs (dict[str, Any]):
                Optional. Additional keyword arguments to pass to the
                session service.

        Returns:
            Session: The session instance.

        Raises:
            RuntimeError: If the session is not found.
        """
        if not self._tmpl_attrs.get("session_service"):
            self.set_up()
        session = self._tmpl_attrs.get("session_service").get_session(
            app_name=self._tmpl_attrs.get("app_name"),
            user_id=user_id,
            session_id=session_id,
            **kwargs,
        )
        if not session:
            raise RuntimeError(
                "Session not found. Please create it using .create_session()"
            )
        return session

    def list_sessions(self, *, user_id: str, **kwargs):
        """Lists sessions for the given user.

        Args:
            user_id (str):
                Required. The ID of the user.
            **kwargs (dict[str, Any]):
                Optional. Additional keyword arguments to pass to the
                session service.

        Returns:
            ListSessionsResponse: The list of sessions.
        """
        if not self._tmpl_attrs.get("session_service"):
            self.set_up()
        return self._tmpl_attrs.get("session_service").list_sessions(
            app_name=self._tmpl_attrs.get("app_name"),
            user_id=user_id,
            **kwargs,
        )

    def create_session(
        self,
        *,
        user_id: str,
        session_id: Optional[str] = None,
        state: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Creates a new session.

        Args:
            user_id (str):
                Required. The ID of the user.
            session_id (str):
                Optional. The ID of the session. If not provided, an ID
                will be generated for the session.
            state (dict[str, Any]):
                Optional. The initial state of the session.
            **kwargs (dict[str, Any]):
                Optional. Additional keyword arguments to pass to the
                session service.

        Returns:
            Session: The newly created session instance.
        """
        if not self._tmpl_attrs.get("session_service"):
            self.set_up()
        session = self._tmpl_attrs.get("session_service").create_session(
            app_name=self._tmpl_attrs.get("app_name"),
            user_id=user_id,
            session_id=session_id,
            state=state,
            **kwargs,
        )
        return session

    def delete_session(
        self,
        *,
        user_id: str,
        session_id: str,
        **kwargs,
    ):
        """Deletes a session for the given user.

        Args:
            user_id (str):
                Required. The ID of the user.
            session_id (str):
                Required. The ID of the session.
            **kwargs (dict[str, Any]):
                Optional. Additional keyword arguments to pass to the
                session service.
        """
        if not self._tmpl_attrs.get("session_service"):
            self.set_up()
        self._tmpl_attrs.get("session_service").delete_session(
            app_name=self._tmpl_attrs.get("app_name"),
            user_id=user_id,
            session_id=session_id,
            **kwargs,
        )

    def register_operations(self) -> Dict[str, List[str]]:
        """Registers the operations of the ADK application."""
        return {
            "": [
                "get_session",
                "list_sessions",
                "create_session",
                "delete_session",
            ],
            "stream": ["stream_query", "streaming_agent_run_with_events"],
        }
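A short usage sketch of the `AdkApp` template above, for local testing before deployment. This is hedged: the `Agent(...)` construction follows public ADK examples and is not part of this diff, and `set_up()` is invoked lazily by `stream_query` when no runner exists yet.

```python
# Minimal local smoke test for AdkApp (assumes `google-adk` is installed
# and vertexai.init(...) has been called; the Agent kwargs are illustrative).
from google.adk.agents import Agent
from vertexai.preview.reasoning_engines import AdkApp

app = AdkApp(agent=Agent(name="assistant", model="gemini-2.0-flash"))

# Sessions are keyed by (app_name, user_id); with no session_service_builder
# this falls back to the InMemorySessionService wired up in set_up().
session = app.create_session(user_id="user-123")
for event in app.stream_query(
    message="Hello!",
    user_id="user-123",
    session_id=session.id,
):
    print(event)  # each event is Event.model_dump(exclude_none=True)
```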
@@ -0,0 +1,474 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Union,
)

if TYPE_CHECKING:
    try:
        from autogen import agentchat

        ConversableAgent = agentchat.ConversableAgent
        ChatResult = agentchat.ChatResult
    except ImportError:
        ConversableAgent = Any
        ChatResult = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any


def _prepare_runnable_kwargs(
    runnable_kwargs: Mapping[str, Any],
    system_instruction: str,
    runnable_name: str,
    llm_config: Mapping[str, Any],
) -> Mapping[str, Any]:
    """Prepares the configuration for a runnable, applying defaults and enforcing constraints."""
    if runnable_kwargs is None:
        runnable_kwargs = {}

    if (
        "human_input_mode" in runnable_kwargs
        and runnable_kwargs["human_input_mode"] != "NEVER"
    ):
        from google.cloud.aiplatform import base

        _LOGGER = base.Logger(__name__)
        _LOGGER.warning(
            f"human_input_mode={runnable_kwargs['human_input_mode']} "
            "is not supported. It will be enforced to 'NEVER'."
        )
        runnable_kwargs["human_input_mode"] = "NEVER"

    if "system_message" not in runnable_kwargs and system_instruction:
        runnable_kwargs["system_message"] = system_instruction

    if "name" not in runnable_kwargs:
        runnable_kwargs["name"] = runnable_name

    if "llm_config" not in runnable_kwargs:
        runnable_kwargs["llm_config"] = llm_config

    return runnable_kwargs

def _default_runnable_builder(
    **runnable_kwargs: Any,
) -> "ConversableAgent":
    from autogen import agentchat

    return agentchat.ConversableAgent(**runnable_kwargs)


def _validate_callable_parameters_are_annotated(callable: Callable):
    """Validates that the parameters of the callable have type annotations.

    This ensures that they can be used for constructing AG2 tools that are
    usable with Gemini function calling.
    """
    import inspect

    parameters = dict(inspect.signature(callable).parameters)
    for name, parameter in parameters.items():
        if parameter.annotation == inspect.Parameter.empty:
            raise TypeError(
                f"Callable={callable.__name__} has untyped input_arg={name}. "
                f"Please specify a type when defining it, e.g. `{name}: str`."
            )


def _validate_tools(tools: Sequence[Callable[..., Any]]):
    """Validates that the tools are usable for tool calling."""
    for tool in tools:
        if isinstance(tool, Callable):
            _validate_callable_parameters_are_annotated(tool)

def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple AG2Agents in the same environment,
    it's crucial to manage trace exports carefully.
    Each agent needs its own span processor tied to a unique project ID.
    While we add a new span processor for each agent, this can lead to
    unexpected behavior.
    For instance, with two agents linked to different projects, traces from the
    second agent might be sent to both projects.
    To prevent this and guarantee traces go to the correct project, we overwrite
    the active span processor whenever a new AG2Agent is created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that overrides the tracer provider's
            existing active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor

class AG2Agent:
    """An AG2 Agent.

    See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/ag2
    for details.
    """

    def __init__(
        self,
        model: str,
        runnable_name: str,
        *,
        api_type: Optional[str] = None,
        llm_config: Optional[Mapping[str, Any]] = None,
        system_instruction: Optional[str] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None,
        tools: Optional[Sequence[Callable[..., Any]]] = None,
        enable_tracing: bool = False,
    ):
        """Initializes the AG2 Agent.

        Under-the-hood, assuming .set_up() is called, this will correspond to
        ```python
        # runnable_builder
        runnable = runnable_builder(
            llm_config=llm_config,
            system_message=system_instruction,
            **runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to
        ```python
        # llm_config
        llm_config = {
            "config_list": [{
                "project_id": initializer.global_config.project,
                "location": initializer.global_config.location,
                "model": "gemini-1.0-pro-001",
                "api_type": "google",
            }]
        }

        # runnable_builder
        runnable = ConversableAgent(
            llm_config=llm_config,
            name="Default AG2 Agent",
            system_message="You are a helpful AI Assistant.",
            human_input_mode="NEVER",
        )
        ```

        By default, if `llm_config` is not specified, a default configuration
        will be created using the provided `model` and `api_type`.

        If `runnable_builder` is not specified, a default runnable builder will
        be used, configured with the `system_instruction`, `runnable_name` and
        `runnable_kwargs`.

        Args:
            model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
                Used to create a default `llm_config` if one is not provided.
                This parameter is ignored if `llm_config` is provided.
            runnable_name (str):
                Required. The name of the runnable.
                This name is used as the default `runnable_kwargs["name"]`
                unless `runnable_kwargs` already contains a "name", in which
                case the provided `runnable_kwargs["name"]` will be used.
            api_type (str):
                Optional. The API type to use for the language model.
                Used to create a default `llm_config` if one is not provided.
                This parameter is ignored if `llm_config` is provided.
            llm_config (Mapping[str, Any]):
                Optional. Configuration dictionary for the language model.
                If provided, this configuration will be used directly.
                Otherwise, a default `llm_config` will be created using `model`
                and `api_type`. This `llm_config` is used as the default
                `runnable_kwargs["llm_config"]` unless `runnable_kwargs` already
                contains an "llm_config", in which case the provided
                `runnable_kwargs["llm_config"]` will be used.
            system_instruction (str):
                Optional. The system instruction for the agent.
                This instruction is used as the default
                `runnable_kwargs["system_message"]` unless `runnable_kwargs`
                already contains a "system_message", in which case the provided
                `runnable_kwargs["system_message"]` will be used.
            runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                the runnable. Details of the kwargs can be found in
                https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent.
                `runnable_kwargs` only supports `human_input_mode="NEVER"`.
                Other `human_input_mode` values will trigger a warning.
            runnable_builder (Callable[..., "ConversableAgent"]):
                Optional. Callable that returns a new runnable. This can be used
                for customizing the orchestration logic of the Agent.
                If not provided, a default runnable builder will be used.
            tools (Sequence[Callable[..., Any]]):
                Optional. The tools for the agent to be able to use. All input
                callables (e.g. function or class method) will be converted
                to an AG2 tool. Defaults to None.
            enable_tracing (bool):
                Optional. Whether to enable tracing in Cloud Trace. Defaults to
                False.
        """
        from google.cloud.aiplatform import initializer

        # Set up llm config.
        self._project = initializer.global_config.project
        self._location = initializer.global_config.location
        self._model_name = model or "gemini-1.0-pro-001"
        self._api_type = api_type or "google"
        self._llm_config = llm_config or {
            "config_list": [
                {
                    "project_id": self._project,
                    "location": self._location,
                    "model": self._model_name,
                    "api_type": self._api_type,
                }
            ]
        }
        self._system_instruction = system_instruction
        self._runnable_name = runnable_name
        self._runnable_kwargs = _prepare_runnable_kwargs(
            runnable_kwargs=runnable_kwargs,
            llm_config=self._llm_config,
            system_instruction=self._system_instruction,
            runnable_name=self._runnable_name,
        )

        self._tools = []
        if tools:
            # We validate tools at initialization for actionable feedback
            # before they are deployed.
            _validate_tools(tools)
            self._tools = tools
        self._ag2_tool_objects = []
        self._runnable = None
        self._runnable_builder = runnable_builder

        self._instrumentor = None
        self._enable_tracing = enable_tracing

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It initializes the runnable and binds it with the tools.

        This method should not be called for an object that is being passed
        to the ReasoningEngine service for deployment, as it initializes
        clients that cannot be serialized.
        """
        if self._enable_tracing:
            from vertexai.reasoning_engines import _utils

            cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
            cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
            openinference_autogen = _utils._import_openinference_autogen_or_warn()
            opentelemetry = _utils._import_opentelemetry_or_warn()
            opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
            if all(
                (
                    cloud_trace_exporter,
                    cloud_trace_v2,
                    openinference_autogen,
                    opentelemetry,
                    opentelemetry_sdk_trace,
                )
            ):
                import google.auth

                credentials, _ = google.auth.default()
                span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
                    project_id=self._project,
                    client=cloud_trace_v2.TraceServiceClient(
                        credentials=credentials.with_quota_project(self._project),
                    ),
                )
                span_processor: SpanProcessor = (
                    opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                        span_exporter=span_exporter,
                    )
                )
                tracer_provider: TracerProvider = (
                    opentelemetry.trace.get_tracer_provider()
                )
                # Get the appropriate tracer provider:
                # 1. If _TRACER_PROVIDER is already set, use that.
                # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
                #    variable is set, use that.
                # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
                # If none of the above is set, we log a warning and create a
                # tracer provider.
                if not tracer_provider:
                    from google.cloud.aiplatform import base

                    _LOGGER = base.Logger(__name__)
                    _LOGGER.warning(
                        "No tracer provider. By default, "
                        "we should get one of the following providers: "
                        "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                        "or _PROXY_TRACER_PROVIDER."
                    )
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids AttributeError:
                # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
                # attribute 'add_span_processor'.
                if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids the "OpenTelemetry client already exists" error.
                _override_active_span_processor(
                    tracer_provider,
                    opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
                )
                tracer_provider.add_span_processor(span_processor)
                # Keep the instrumentation up-to-date when creating multiple
                # AG2Agents. We deliberately override the instrumentor each
                # time, so that if different agents end up using different
                # instrumentations, the user is always working with the most
                # recent agent's instrumentation.
                self._instrumentor = openinference_autogen.AutogenInstrumentor()
                self._instrumentor.uninstrument()
                self._instrumentor.instrument()
            else:
                from google.cloud.aiplatform import base

                _LOGGER = base.Logger(__name__)
                _LOGGER.warning(
                    "enable_tracing=True but proceeding with tracing disabled "
                    "because not all packages for tracing have been installed"
                )

        # Set up tools.
        if self._tools and not self._ag2_tool_objects:
            from vertexai.reasoning_engines import _utils

            autogen_tools = _utils._import_autogen_tools_or_warn()
            if autogen_tools:
                for tool in self._tools:
                    self._ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))

        # Set up runnable.
        runnable_builder = self._runnable_builder or _default_runnable_builder
        self._runnable = runnable_builder(
            **self._runnable_kwargs,
        )

    def clone(self) -> "AG2Agent":
        """Returns a clone of the AG2Agent."""
        import copy

        return AG2Agent(
            model=self._model_name,
            api_type=self._api_type,
            llm_config=copy.deepcopy(self._llm_config),
            system_instruction=self._system_instruction,
            runnable_name=self._runnable_name,
            tools=copy.deepcopy(self._tools),
            runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
            runnable_builder=self._runnable_builder,
            enable_tracing=self._enable_tracing,
        )

    def query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        max_turns: Optional[int] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Queries the Agent with the given input.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            max_turns (int):
                Optional. The maximum number of turns to run the agent for.
                If not provided, the agent will run indefinitely.
                If `max_turns` is a `float`, it will be converted to `int`
                through rounding.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.run()` method of the corresponding runnable.
                Details of the kwargs can be found in
                https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent#run.
                The `user_input` parameter defaults to `False`, and should not
                be passed through `kwargs`.

        Returns:
            The output of querying the Agent with the given input.
        """
        if isinstance(input, str):
            input = {"content": input}

        if max_turns and isinstance(max_turns, float):
            # Supporting auto-conversion of float to int.
            max_turns = round(max_turns)

        if "user_input" in kwargs:
            from google.cloud.aiplatform import base

            _LOGGER = base.Logger(__name__)
            _LOGGER.warning(
                "The `user_input` parameter should not be passed through "
                "kwargs. `user_input` defaults to `False`."
            )
            kwargs.pop("user_input")

        if not self._runnable:
            self.set_up()

        from vertexai.reasoning_engines import _utils

        # `.run()` will return a `ChatResult` object, which is a dataclass.
        # We need to convert it to a JSON-serializable object.
        # More details of `ChatResult` can be found in
        # https://docs.ag2.ai/docs/api-reference/autogen/ChatResult.
        return _utils.dataclass_to_dict(
            self._runnable.run(
                input,
                user_input=False,
                tools=self._ag2_tool_objects,
                max_turns=max_turns,
                **kwargs,
            )
        )
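Similarly, a hedged sketch for `AG2Agent`, mirroring the defaults spelled out in the `__init__` docstring. The `multiply` tool below is illustrative; any callable with fully annotated parameters passes `_validate_tools`.

```python
# Illustrative only: AG2Agent converts any annotated callable into an
# autogen Tool during set_up().
from vertexai.preview.reasoning_engines import AG2Agent

def multiply(a: int, b: int) -> int:
    """Multiplies two integers."""
    return a * b

agent = AG2Agent(
    model="gemini-1.0-pro-001",   # used for the default llm_config
    runnable_name="math-agent",   # becomes runnable_kwargs["name"]
    tools=[multiply],             # validated eagerly in __init__
)
# query() lazily calls set_up(), forces user_input=False, and returns
# the ChatResult dataclass converted to a JSON-serializable dict.
print(agent.query(input="What is 6 times 7?", max_turns=2))
```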
@@ -0,0 +1,643 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    Union,
)

if TYPE_CHECKING:
    try:
        from langchain_core import runnables
        from langchain_core import tools as lc_tools
        from langchain_core.language_models import base as lc_language_models

        BaseTool = lc_tools.BaseTool
        BaseLanguageModel = lc_language_models.BaseLanguageModel
        GetSessionHistoryCallable = runnables.history.GetSessionHistoryCallable
        RunnableConfig = runnables.RunnableConfig
        RunnableSerializable = runnables.RunnableSerializable
    except ImportError:
        BaseTool = Any
        BaseLanguageModel = Any
        GetSessionHistoryCallable = Any
        RunnableConfig = Any
        RunnableSerializable = Any

    try:
        from langchain_google_vertexai.functions_utils import _ToolsType

        _ToolLike = _ToolsType
    except ImportError:
        _ToolLike = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any


def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
    # https://github.com/langchain-ai/langchain/blob/5784dfed001730530637793bea1795d9d5a7c244/libs/core/langchain_core/runnables/history.py#L237-L241
    runnable_kwargs = {
        # input_messages_key (str): Must be specified if the underlying
        # agent accepts a dict as input.
        "input_messages_key": "input",
        # output_messages_key (str): Must be specified if the underlying
        # agent returns a dict as output.
        "output_messages_key": "output",
    }
    if has_history:
        # history_messages_key (str): Must be specified if the underlying
        # agent accepts a dict as input and a separate key for historical
        # messages.
        runnable_kwargs["history_messages_key"] = "history"
    return runnable_kwargs


def _default_output_parser():
    try:
        from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
    except (ModuleNotFoundError, ImportError):
        # Fallback to an older version if needed.
        from langchain.agents.output_parsers.openai_tools import (
            OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
        )

    return ToolsAgentOutputParser()

def _default_model_builder(
    model_name: str,
    *,
    project: str,
    location: str,
    model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
    import vertexai
    from google.cloud.aiplatform import initializer
    from langchain_google_vertexai import ChatVertexAI

    model_kwargs = model_kwargs or {}
    current_project = initializer.global_config.project
    current_location = initializer.global_config.location
    vertexai.init(project=project, location=location)
    model = ChatVertexAI(model_name=model_name, **model_kwargs)
    vertexai.init(project=current_project, location=current_location)
    return model

def _default_runnable_builder(
    model: "BaseLanguageModel",
    *,
    system_instruction: Optional[str] = None,
    tools: Optional[Sequence["_ToolLike"]] = None,
    prompt: Optional["RunnableSerializable"] = None,
    output_parser: Optional["RunnableSerializable"] = None,
    chat_history: Optional["GetSessionHistoryCallable"] = None,
    model_tool_kwargs: Optional[Mapping[str, Any]] = None,
    agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
    runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
    from langchain_core import tools as lc_tools
    from langchain.agents import AgentExecutor
    from langchain.tools.base import StructuredTool

    # The prompt template and runnable_kwargs need to be customized depending
    # on whether the user intends for the agent to have history. The way the
    # user would reflect that is by setting chat_history (which defaults to
    # None).
    has_history: bool = chat_history is not None
    prompt = prompt or _default_prompt(
        has_history=has_history,
        system_instruction=system_instruction,
    )
    output_parser = output_parser or _default_output_parser()
    model_tool_kwargs = model_tool_kwargs or {}
    agent_executor_kwargs = agent_executor_kwargs or {}
    runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history)
    if tools:
        model = model.bind_tools(tools=tools, **model_tool_kwargs)
    else:
        tools = []
    agent_executor = AgentExecutor(
        agent=prompt | model | output_parser,
        tools=[
            tool
            if isinstance(tool, lc_tools.BaseTool)
            else StructuredTool.from_function(tool)
            for tool in tools
            if isinstance(tool, (Callable, lc_tools.BaseTool))
        ],
        **agent_executor_kwargs,
    )
    if has_history:
        from langchain_core.runnables.history import RunnableWithMessageHistory

        return RunnableWithMessageHistory(
            runnable=agent_executor,
            get_session_history=chat_history,
            **runnable_kwargs,
        )
    return agent_executor

def _default_prompt(
    has_history: bool,
    system_instruction: Optional[str] = None,
) -> "RunnableSerializable":
    from langchain_core import prompts

    try:
        from langchain.agents.format_scratchpad.tools import format_to_tool_messages
    except (ModuleNotFoundError, ImportError):
        # Fallback to an older version if needed.
        from langchain.agents.format_scratchpad.openai_tools import (
            format_to_openai_tool_messages as format_to_tool_messages,
        )

    system_instructions = []
    if system_instruction:
        system_instructions = [("system", system_instruction)]

    if has_history:
        return {
            "history": lambda x: x["history"],
            "input": lambda x: x["input"],
            "agent_scratchpad": (
                lambda x: format_to_tool_messages(x["intermediate_steps"])
            ),
        } | prompts.ChatPromptTemplate.from_messages(
            system_instructions
            + [
                prompts.MessagesPlaceholder(variable_name="history"),
                ("user", "{input}"),
                prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )
    else:
        return {
            "input": lambda x: x["input"],
            "agent_scratchpad": (
                lambda x: format_to_tool_messages(x["intermediate_steps"])
            ),
        } | prompts.ChatPromptTemplate.from_messages(
            system_instructions
            + [
                ("user", "{input}"),
                prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )

def _validate_callable_parameters_are_annotated(callable: Callable):
    """Validates that the parameters of the callable have type annotations.

    This ensures that they can be used for constructing LangChain tools that
    are usable with Gemini function calling.
    """
    import inspect

    parameters = dict(inspect.signature(callable).parameters)
    for name, parameter in parameters.items():
        if parameter.annotation == inspect.Parameter.empty:
            raise TypeError(
                f"Callable={callable.__name__} has untyped input_arg={name}. "
                f"Please specify a type when defining it, e.g. `{name}: str`."
            )


def _validate_tools(tools: Sequence["_ToolLike"]):
    """Validates that the tools are usable for tool calling."""
    for tool in tools:
        if isinstance(tool, Callable):
            _validate_callable_parameters_are_annotated(tool)

def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple LangchainAgents in the same environment,
    it's crucial to manage trace exports carefully.
    Each agent needs its own span processor tied to a unique project ID.
    While we add a new span processor for each agent, this can lead to
    unexpected behavior.
    For instance, with two agents linked to different projects, traces from the
    second agent might be sent to both projects.
    To prevent this and guarantee traces go to the correct project, we overwrite
    the active span processor whenever a new LangchainAgent is created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that overrides the tracer provider's
            existing active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor

class LangchainAgent:
    """A Langchain Agent.

    See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/langchain
    for details.
    """

    def __init__(
        self,
        model: str,
        *,
        system_instruction: Optional[str] = None,
        prompt: Optional["RunnableSerializable"] = None,
        tools: Optional[Sequence["_ToolLike"]] = None,
        output_parser: Optional["RunnableSerializable"] = None,
        chat_history: Optional["GetSessionHistoryCallable"] = None,
        model_kwargs: Optional[Mapping[str, Any]] = None,
        model_tool_kwargs: Optional[Mapping[str, Any]] = None,
        agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        model_builder: Optional[Callable] = None,
        runnable_builder: Optional[Callable] = None,
        enable_tracing: bool = False,
    ):
        """Initializes the LangchainAgent.

        Under-the-hood, assuming .set_up() is called, this will correspond to

        ```
        model = model_builder(model_name=model, model_kwargs=model_kwargs)
        runnable = runnable_builder(
            prompt=prompt,
            model=model,
            tools=tools,
            output_parser=output_parser,
            chat_history=chat_history,
            agent_executor_kwargs=agent_executor_kwargs,
            runnable_kwargs=runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to
        ```
        # model_builder
        from langchain_google_vertexai import ChatVertexAI
        llm = ChatVertexAI(model_name=model, **model_kwargs)

        # runnable_builder
        from langchain import agents
        from langchain_core.runnables.history import RunnableWithMessageHistory
        llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
        agent_executor = agents.AgentExecutor(
            agent=prompt | llm_with_tools | output_parser,
            tools=tools,
            **agent_executor_kwargs,
        )
        runnable = RunnableWithMessageHistory(
            runnable=agent_executor,
            get_session_history=chat_history,
            **runnable_kwargs,
        )
        ```

        Args:
            model (str):
                Required. The name of the model (e.g. "gemini-1.0-pro").
            system_instruction (str):
                Optional. The system instruction to use for the agent. This
                argument should not be specified if `prompt` is specified.
            prompt (langchain_core.runnables.RunnableSerializable):
                Optional. The prompt template for the model. Defaults to a
                ChatPromptTemplate.
            tools (Sequence[langchain_core.tools.BaseTool, Callable]):
                Optional. The tools for the agent to be able to use. All input
                callables (e.g. function or class method) will be converted
                to a langchain.tools.base.StructuredTool. Defaults to None.
            output_parser (langchain_core.runnables.RunnableSerializable):
                Optional. The output parser for the model. Defaults to an
                output parser that works with Gemini function-calling.
            chat_history (langchain_core.runnables.history.GetSessionHistoryCallable):
                Optional. Callable that returns a new BaseChatMessageHistory.
                Defaults to None, i.e. chat_history is not preserved.
            model_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                chat_models.ChatVertexAI. An example would be
                ```
                {
                    # temperature (float): Sampling temperature; it controls
                    # the degree of randomness in token selection.
                    "temperature": 0.28,
                    # max_output_tokens (int): Token limit determines the
                    # maximum amount of text output from one prompt.
                    "max_output_tokens": 1000,
                    # top_p (float): Tokens are selected from most probable to
                    # least, until the sum of their probabilities equals the
                    # top_p value.
                    "top_p": 0.95,
                    # top_k (int): How the model selects tokens for output; the
                    # next token is selected from among the top_k most probable
                    # tokens.
                    "top_k": 40,
                }
                ```
            model_tool_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments when binding tools to the
                model using `model.bind_tools()`.
            agent_executor_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                langchain.agents.AgentExecutor. An example would be
                ```
                {
                    # Whether to return the agent's trajectory of intermediate
                    # steps at the end in addition to the final output.
                    "return_intermediate_steps": False,
                    # The maximum number of steps to take before ending the
                    # execution loop.
                    "max_iterations": 15,
                    # The method to use for early stopping if the agent never
                    # returns `AgentFinish`. Either 'force' or 'generate'.
                    "early_stopping_method": "force",
                    # How to handle errors raised by the agent's output parser.
                    # Defaults to `False`, which raises the error.
                    "handle_parsing_errors": False,
                }
                ```
            runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                langchain.runnables.history.RunnableWithMessageHistory if
                chat_history is specified. If chat_history is None, this will be
                ignored.
            model_builder (Callable):
                Optional. Callable that returns a new language model. Defaults
                to a callable that returns ChatVertexAI based on `model`,
                `model_kwargs` and the parameters in `vertexai.init`.
            runnable_builder (Callable):
                Optional. Callable that returns a new runnable. This can be used
                for customizing the orchestration logic of the Agent based on
                the model returned by `model_builder` and the rest of the input
                arguments.
            enable_tracing (bool):
                Optional. Whether to enable tracing in Cloud Trace. Defaults to
                False.

        Raises:
            ValueError: If both `prompt` and `system_instruction` are specified.
            TypeError: If there is an invalid tool (e.g. function with an input
                that did not specify its type).
        """
        from google.cloud.aiplatform import initializer

        self._project = initializer.global_config.project
        self._location = initializer.global_config.location
        self._tools = []
        if tools:
            # We validate tools at initialization for actionable feedback
            # before they are deployed.
            _validate_tools(tools)
            self._tools = tools
        if prompt and system_instruction:
            raise ValueError(
                "Only one of `prompt` or `system_instruction` should be specified. "
                "Consider incorporating the system instruction into the prompt "
                "rather than passing it separately as an argument."
            )
        self._model_name = model
        self._system_instruction = system_instruction
        self._prompt = prompt
        self._output_parser = output_parser
        self._chat_history = chat_history
        self._model_kwargs = model_kwargs
        self._model_tool_kwargs = model_tool_kwargs
        self._agent_executor_kwargs = agent_executor_kwargs
        self._runnable_kwargs = runnable_kwargs
        self._model = None
        self._model_builder = model_builder
        self._runnable = None
        self._runnable_builder = runnable_builder
        self._instrumentor = None
        self._enable_tracing = enable_tracing

def set_up(self):
|
||||
"""Sets up the agent for execution of queries at runtime.
|
||||
|
||||
It initializes the model, binds the model with tools, and connects it
|
||||
with the prompt template and output parser.
|
||||
|
||||
This method should not be called for an object that being passed to
|
||||
the ReasoningEngine service for deployment, as it initializes clients
|
||||
that can not be serialized.
|
||||
"""
        if self._enable_tracing:
            from vertexai.reasoning_engines import _utils

            cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
            cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
            openinference_langchain = _utils._import_openinference_langchain_or_warn()
            opentelemetry = _utils._import_opentelemetry_or_warn()
            opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
            if all(
                (
                    cloud_trace_exporter,
                    cloud_trace_v2,
                    openinference_langchain,
                    opentelemetry,
                    opentelemetry_sdk_trace,
                )
            ):
                import google.auth

                credentials, _ = google.auth.default()
                span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
                    project_id=self._project,
                    client=cloud_trace_v2.TraceServiceClient(
                        credentials=credentials.with_quota_project(self._project),
                    ),
                )
                span_processor: SpanProcessor = (
                    opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                        span_exporter=span_exporter,
                    )
                )
                tracer_provider: TracerProvider = (
                    opentelemetry.trace.get_tracer_provider()
                )
                # Get the appropriate tracer provider:
                # 1. If _TRACER_PROVIDER is already set, use that.
                # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
                #    variable is set, use that.
                # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
                # If none of the above is set, we log a warning and
                # create a tracer provider.
                if not tracer_provider:
                    from google.cloud.aiplatform import base

                    _LOGGER = base.Logger(__name__)
                    _LOGGER.warning(
                        "No tracer provider. By default, "
                        "we should get one of the following providers: "
                        "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                        "or _PROXY_TRACER_PROVIDER."
                    )
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids AttributeError:
                # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
                # attribute 'add_span_processor'.
                if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids the OpenTelemetry "client already exists" error.
                _override_active_span_processor(
                    tracer_provider,
                    opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
                )
                tracer_provider.add_span_processor(span_processor)
                # When creating multiple LangchainAgents,
                # we need to keep the instrumentation up-to-date.
                # We deliberately override the instrumentation each time,
                # so that if different agents end up using different
                # instrumentations, we guarantee that the user is always
                # working with the most recent agent's instrumentation.
                self._instrumentor = openinference_langchain.LangChainInstrumentor()
                if self._instrumentor.is_instrumented_by_opentelemetry:
                    self._instrumentor.uninstrument()
                self._instrumentor.instrument()
            else:
                from google.cloud.aiplatform import base

                _LOGGER = base.Logger(__name__)
                _LOGGER.warning(
                    "enable_tracing=True but proceeding with tracing disabled "
                    "because not all packages for tracing have been installed"
                )
        model_builder = self._model_builder or _default_model_builder
        self._model = model_builder(
            model_name=self._model_name,
            model_kwargs=self._model_kwargs,
            project=self._project,
            location=self._location,
        )
        runnable_builder = self._runnable_builder or _default_runnable_builder
        self._runnable = runnable_builder(
            prompt=self._prompt,
            model=self._model,
            tools=self._tools,
            system_instruction=self._system_instruction,
            output_parser=self._output_parser,
            chat_history=self._chat_history,
            model_tool_kwargs=self._model_tool_kwargs,
            agent_executor_kwargs=self._agent_executor_kwargs,
            runnable_kwargs=self._runnable_kwargs,
        )

    def clone(self) -> "LangchainAgent":
        """Returns a clone of the LangchainAgent."""
        import copy

        return LangchainAgent(
            model=self._model_name,
            system_instruction=self._system_instruction,
            prompt=copy.deepcopy(self._prompt),
            tools=copy.deepcopy(self._tools),
            output_parser=copy.deepcopy(self._output_parser),
            chat_history=copy.deepcopy(self._chat_history),
            model_kwargs=copy.deepcopy(self._model_kwargs),
            model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
            agent_executor_kwargs=copy.deepcopy(self._agent_executor_kwargs),
            runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
            model_builder=self._model_builder,
            runnable_builder=self._runnable_builder,
            enable_tracing=self._enable_tracing,
        )

    def query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.invoke()` method of the corresponding AgentExecutor.

        Returns:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._runnable:
            self.set_up()
        return langchain_load_dump.dumpd(
            self._runnable.invoke(input=input, config=config, **kwargs)
        )

    def stream_query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs,
    ) -> Iterable[Any]:
        """Stream queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.stream()` method of the corresponding AgentExecutor.

        Yields:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._runnable:
            self.set_up()
        for chunk in self._runnable.stream(input=input, config=config, **kwargs):
            yield langchain_load_dump.dumpd(chunk)
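
As a rough usage sketch of the streaming path above (assuming `vertexai.init(...)` has been called and the LangChain packages are installed), each chunk yielded by `stream_query` has already been converted into a JSON-serializable mapping by `langchain.load.dump.dumpd`:

```python
from vertexai.preview.reasoning_engines import LangchainAgent

agent = LangchainAgent(model="gemini-1.0-pro")
for chunk in agent.stream_query(input="Tell me a short story."):
    # Each chunk is a JSON-serializable dict.
    print(chunk)
```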
@@ -0,0 +1,658 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    Union,
)

if TYPE_CHECKING:
    try:
        from langchain_core import runnables
        from langchain_core import tools as lc_tools
        from langchain_core.language_models import base as lc_language_models

        BaseTool = lc_tools.BaseTool
        BaseLanguageModel = lc_language_models.BaseLanguageModel
        RunnableConfig = runnables.RunnableConfig
        RunnableSerializable = runnables.RunnableSerializable
    except ImportError:
        BaseTool = Any
        BaseLanguageModel = Any
        RunnableConfig = Any
        RunnableSerializable = Any

    try:
        from langchain_google_vertexai.functions_utils import _ToolsType

        _ToolLike = _ToolsType
    except ImportError:
        _ToolLike = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any

    try:
        from langgraph_checkpoint.checkpoint import base

        BaseCheckpointSaver = base.BaseCheckpointSaver
    except ImportError:
        try:
            from langgraph.checkpoint import base

            BaseCheckpointSaver = base.BaseCheckpointSaver
        except ImportError:
            BaseCheckpointSaver = Any


def _default_model_builder(
    model_name: str,
    *,
    project: str,
    location: str,
    model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "BaseLanguageModel":
    """Default callable for building a language model.

    Args:
        model_name (str):
            Required. The name of the model (e.g. "gemini-1.0-pro").
        project (str):
            Required. The Google Cloud project ID.
        location (str):
            Required. The Google Cloud location.
        model_kwargs (Mapping[str, Any]):
            Optional. Additional keyword arguments for the constructor of
            chat_models.ChatVertexAI.

    Returns:
        BaseLanguageModel: The language model.
    """
    import vertexai
    from google.cloud.aiplatform import initializer
    from langchain_google_vertexai import ChatVertexAI

    model_kwargs = model_kwargs or {}
    current_project = initializer.global_config.project
    current_location = initializer.global_config.location
    vertexai.init(project=project, location=location)
    model = ChatVertexAI(model_name=model_name, **model_kwargs)
    vertexai.init(project=current_project, location=current_location)
    return model


def _default_runnable_builder(
    model: "BaseLanguageModel",
    *,
    tools: Optional[Sequence["_ToolLike"]] = None,
    checkpointer: Optional[Any] = None,
    model_tool_kwargs: Optional[Mapping[str, Any]] = None,
    runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
    """Default callable for building a runnable.

    Args:
        model (BaseLanguageModel):
            Required. The language model.
        tools (Optional[Sequence[_ToolLike]]):
            Optional. The tools for the agent to be able to use.
        checkpointer (Optional[BaseCheckpointSaver]):
            Optional. The checkpointer for the agent.
        model_tool_kwargs (Optional[Mapping[str, Any]]):
            Optional. Additional keyword arguments when binding tools to the model.
        runnable_kwargs (Optional[Mapping[str, Any]]):
            Optional. Additional keyword arguments for the runnable.

    Returns:
        RunnableSerializable: The runnable.
    """
    from langgraph import prebuilt as langgraph_prebuilt

    model_tool_kwargs = model_tool_kwargs or {}
    runnable_kwargs = runnable_kwargs or {}
    if tools:
        model = model.bind_tools(tools=tools, **model_tool_kwargs)
    else:
        tools = []
    if checkpointer:
        if "checkpointer" in runnable_kwargs:
            from google.cloud.aiplatform import base

            base.Logger(__name__).warning(
                "checkpointer is being specified in both checkpointer_builder "
                "and runnable_kwargs. Please specify it in only one of them. "
                "Overriding the checkpointer in runnable_kwargs."
            )
        runnable_kwargs["checkpointer"] = checkpointer
    return langgraph_prebuilt.create_react_agent(
        model,
        tools=tools,
        **runnable_kwargs,
    )


def _validate_callable_parameters_are_annotated(callable: Callable):
    """Validates that the parameters of the callable have type annotations.

    This ensures that they can be used for constructing LangChain tools that are
    usable with Gemini function calling.

    Args:
        callable (Callable): The callable to validate.

    Raises:
        TypeError: If any parameter is not annotated.
    """
    import inspect

    parameters = dict(inspect.signature(callable).parameters)
    for name, parameter in parameters.items():
        if parameter.annotation == inspect.Parameter.empty:
            raise TypeError(
                f"Callable={callable.__name__} has untyped input_arg={name}. "
                f"Please specify a type when defining it, e.g. `{name}: str`."
            )
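

# Illustrative sketch (not part of the original module): how the validation
# above behaves for annotated vs. unannotated callables. The helper below is
# hypothetical and is never called by the SDK itself.
def _example_tool_validation():
    def good_tool(city: str) -> str:
        return f"Weather for {city}"

    _validate_callable_parameters_are_annotated(good_tool)  # passes silently

    def bad_tool(city):
        return f"Weather for {city}"

    try:
        _validate_callable_parameters_are_annotated(bad_tool)
    except TypeError as e:
        # e.g. "Callable=bad_tool has untyped input_arg=city. ..."
        print(e)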


def _validate_tools(tools: Sequence["_ToolLike"]):
    """Validates that the tools are usable for tool calling.

    Args:
        tools (Sequence[_ToolLike]): The tools to validate.

    Raises:
        TypeError: If any tool is a callable with untyped parameters.
    """
    for tool in tools:
        if isinstance(tool, Callable):
            _validate_callable_parameters_are_annotated(tool)


def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple LanggraphAgents in the same environment,
    it's crucial to manage trace exports carefully.
    Each agent needs its own span processor tied to a unique project ID.
    While we add a new span processor for each agent, this can lead to
    unexpected behavior.
    For instance, with two agents linked to different projects, traces from the
    second agent might be sent to both projects.
    To prevent this and guarantee traces go to the correct project, we overwrite
    the active span processor whenever a new LanggraphAgent is created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that overrides the tracer provider's
            active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor


class LanggraphAgent:
    """A LangGraph Agent.

    See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/langgraph
    for details.
    """

    def __init__(
        self,
        model: str,
        *,
        tools: Optional[Sequence["_ToolLike"]] = None,
        model_kwargs: Optional[Mapping[str, Any]] = None,
        model_tool_kwargs: Optional[Mapping[str, Any]] = None,
        model_builder: Optional[Callable[..., "BaseLanguageModel"]] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_builder: Optional[Callable[..., "RunnableSerializable"]] = None,
        checkpointer_kwargs: Optional[Mapping[str, Any]] = None,
        checkpointer_builder: Optional[Callable[..., "BaseCheckpointSaver"]] = None,
        enable_tracing: bool = False,
    ):
        """Initializes the LangGraph Agent.

        Under-the-hood, assuming .set_up() is called, this will correspond to
        ```python
        model = model_builder(model_name=model, model_kwargs=model_kwargs)
        runnable = runnable_builder(
            model=model,
            tools=tools,
            model_tool_kwargs=model_tool_kwargs,
            runnable_kwargs=runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to
        ```python
        # model_builder
        from langchain_google_vertexai import ChatVertexAI
        llm = ChatVertexAI(model_name=model, **model_kwargs)

        # runnable_builder
        from langgraph.prebuilt import create_react_agent
        llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
        runnable = create_react_agent(
            llm_with_tools,
            tools=tools,
            **runnable_kwargs,
        )
        ```

        By default, no checkpointer is used (i.e. there is no state history). To
        enable checkpointing, provide a `checkpointer_builder` function that
        returns a checkpointer instance.

        **Example using Spanner:**
        ```python
        def checkpointer_builder(instance_id, database_id, project_id, **kwargs):
            from langchain_google_spanner import SpannerCheckpointSaver

            checkpointer = SpannerCheckpointSaver(instance_id, database_id, project_id)
            with checkpointer.cursor() as cur:
                cur.execute("DROP TABLE IF EXISTS checkpoints")
                cur.execute("DROP TABLE IF EXISTS checkpoint_writes")
            checkpointer.setup()

            return checkpointer
        ```

        **Example using an in-memory checkpointer:**
        ```python
        def checkpointer_builder(**kwargs):
            from langgraph.checkpoint.memory import MemorySaver

            return MemorySaver()
        ```

        The `checkpointer_builder` function will be called with the keyword
        arguments specified in `checkpointer_kwargs`. Ensure your
        `checkpointer_builder` function accepts `**kwargs` to handle these
        arguments, even if unused.
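
        For example, a minimal sketch (with illustrative values) wiring the
        in-memory checkpointer above into an agent and reading back its state:
        ```python
        agent = LanggraphAgent(
            model="gemini-1.0-pro",
            checkpointer_builder=checkpointer_builder,
        )
        config = {"configurable": {"thread_id": "demo-thread"}}
        agent.query(input="Hello!", config=config)
        state = agent.get_state(config=config)
        ```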

        Args:
            model (str):
                Optional. The name of the model (e.g. "gemini-1.0-pro").
            tools (Sequence[langchain_core.tools.BaseTool, Callable]):
                Optional. The tools for the agent to be able to use. All input
                callables (e.g. function or class method) will be converted
                to a langchain.tools.base.StructuredTool. Defaults to None.
            model_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                chat_models.ChatVertexAI. An example would be
                ```
                {
                    # temperature (float): Sampling temperature, it controls the
                    # degree of randomness in token selection.
                    "temperature": 0.28,
                    # max_output_tokens (int): Token limit determines the
                    # maximum amount of text output from one prompt.
                    "max_output_tokens": 1000,
                    # top_p (float): Tokens are selected from most probable to
                    # least, until the sum of their probabilities equals the
                    # top_p value.
                    "top_p": 0.95,
                    # top_k (int): How the model selects tokens for output, the
                    # next token is selected from among the top_k most probable
                    # tokens.
                    "top_k": 40,
                }
                ```
            model_tool_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments when binding tools to the
                model using `model.bind_tools()`.
            model_builder (Callable[..., "BaseLanguageModel"]):
                Optional. Callable that returns a new language model. Defaults
                to a callable that returns ChatVertexAI based on `model`,
                `model_kwargs` and the parameters in `vertexai.init`.
            runnable_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the runnable builder
                (by default, `langgraph.prebuilt.create_react_agent`).
            runnable_builder (Callable[..., "RunnableSerializable"]):
                Optional. Callable that returns a new runnable. This can be used
                for customizing the orchestration logic of the Agent based on
                the model returned by `model_builder` and the rest of the input
                arguments.
            checkpointer_kwargs (Mapping[str, Any]):
                Optional. Additional keyword arguments for the constructor of
                the checkpointer returned by `checkpointer_builder`.
            checkpointer_builder (Callable[..., "BaseCheckpointSaver"]):
                Optional. Callable that returns a checkpointer. This can be used
                for defining the checkpointer of the Agent. Defaults to None.
            enable_tracing (bool):
                Optional. Whether to enable tracing in Cloud Trace. Defaults to
                False.

        Raises:
            TypeError: If there is an invalid tool (e.g. function with an input
                that did not specify its type).
        """
        from google.cloud.aiplatform import initializer

        self._project = initializer.global_config.project
        self._location = initializer.global_config.location
        self._tools = []
        if tools:
            # We validate tools at initialization for actionable feedback before
            # they are deployed.
            _validate_tools(tools)
            self._tools = tools
        self._model_name = model
        self._model_kwargs = model_kwargs
        self._model_tool_kwargs = model_tool_kwargs
        self._runnable_kwargs = runnable_kwargs
        self._checkpointer_kwargs = checkpointer_kwargs
        self._model = None
        self._model_builder = model_builder
        self._runnable = None
        self._runnable_builder = runnable_builder
        self._checkpointer_builder = checkpointer_builder
        self._instrumentor = None
        self._enable_tracing = enable_tracing

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It initializes the model, binds the model with tools, and builds the
        LangGraph runnable (including the checkpointer, if one is configured).

        This method should not be called for an object that is being passed to
        the ReasoningEngine service for deployment, as it initializes clients
        that cannot be serialized.
        """
        if self._enable_tracing:
            from vertexai.reasoning_engines import _utils

            cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
            cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
            openinference_langchain = _utils._import_openinference_langchain_or_warn()
            opentelemetry = _utils._import_opentelemetry_or_warn()
            opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
            if all(
                (
                    cloud_trace_exporter,
                    cloud_trace_v2,
                    openinference_langchain,
                    opentelemetry,
                    opentelemetry_sdk_trace,
                )
            ):
                import google.auth

                credentials, _ = google.auth.default()
                span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
                    project_id=self._project,
                    client=cloud_trace_v2.TraceServiceClient(
                        credentials=credentials.with_quota_project(self._project),
                    ),
                )
                span_processor: SpanProcessor = (
                    opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                        span_exporter=span_exporter,
                    )
                )
                tracer_provider: TracerProvider = (
                    opentelemetry.trace.get_tracer_provider()
                )
                # Get the appropriate tracer provider:
                # 1. If _TRACER_PROVIDER is already set, use that.
                # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
                #    variable is set, use that.
                # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
                # If none of the above is set, we log a warning and
                # create a tracer provider.
                if not tracer_provider:
                    from google.cloud.aiplatform import base

                    base.Logger(__name__).warning(
                        "No tracer provider. By default, "
                        "we should get one of the following providers: "
                        "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                        "or _PROXY_TRACER_PROVIDER."
                    )
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids AttributeError:
                # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
                # attribute 'add_span_processor'.
                if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids the OpenTelemetry "client already exists" error.
                _override_active_span_processor(
                    tracer_provider,
                    opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
                )
                tracer_provider.add_span_processor(span_processor)
                # When creating multiple LanggraphAgents,
                # we need to keep the instrumentation up-to-date.
                # We deliberately override the instrumentation each time,
                # so that if different agents end up using different
                # instrumentations, we guarantee that the user is always
                # working with the most recent agent's instrumentation.
                self._instrumentor = openinference_langchain.LangChainInstrumentor()
                if self._instrumentor.is_instrumented_by_opentelemetry:
                    self._instrumentor.uninstrument()
                self._instrumentor.instrument()
            else:
                from google.cloud.aiplatform import base

                _LOGGER = base.Logger(__name__)
                _LOGGER.warning(
                    "enable_tracing=True but proceeding with tracing disabled "
                    "because not all packages for tracing have been installed"
                )
        model_builder = self._model_builder or _default_model_builder
        self._model = model_builder(
            model_name=self._model_name,
            model_kwargs=self._model_kwargs,
            project=self._project,
            location=self._location,
        )
        self._checkpointer = None
        if self._checkpointer_builder:
            checkpointer_kwargs = self._checkpointer_kwargs or {}
            self._checkpointer = self._checkpointer_builder(
                **checkpointer_kwargs,
            )
        runnable_builder = self._runnable_builder or _default_runnable_builder
        self._runnable = runnable_builder(
            model=self._model,
            tools=self._tools,
            checkpointer=self._checkpointer,
            model_tool_kwargs=self._model_tool_kwargs,
            runnable_kwargs=self._runnable_kwargs,
        )

    def clone(self) -> "LanggraphAgent":
        """Returns a clone of the LanggraphAgent."""
        import copy

        return LanggraphAgent(
            model=self._model_name,
            tools=copy.deepcopy(self._tools),
            model_kwargs=copy.deepcopy(self._model_kwargs),
            model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
            runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
            checkpointer_kwargs=copy.deepcopy(self._checkpointer_kwargs),
            model_builder=self._model_builder,
            runnable_builder=self._runnable_builder,
            checkpointer_builder=self._checkpointer_builder,
            enable_tracing=self._enable_tracing,
        )

    def query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.invoke()` method of the corresponding runnable.

        Returns:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._runnable:
            self.set_up()
        return langchain_load_dump.dumpd(
            self._runnable.invoke(input=input, config=config, **kwargs)
        )

    def stream_query(
        self,
        *,
        input: Union[str, Mapping[str, Any]],
        config: Optional["RunnableConfig"] = None,
        **kwargs,
    ) -> Iterable[Any]:
        """Stream queries the Agent with the given input and config.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            config (langchain_core.runnables.RunnableConfig):
                Optional. The config (if any) to be used for invoking the Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.stream()` method of the corresponding runnable.

        Yields:
            The output of querying the Agent with the given input and config.
        """
        from langchain.load import dump as langchain_load_dump

        if isinstance(input, str):
            input = {"input": input}
        if not self._runnable:
            self.set_up()
        for chunk in self._runnable.stream(input=input, config=config, **kwargs):
            yield langchain_load_dump.dumpd(chunk)

    def get_state_history(
        self,
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Iterable[Any]:
        """Gets the state history of the Agent.

        Args:
            config (Optional[RunnableConfig]):
                Optional. The config for invoking the Agent.
            **kwargs:
                Optional. Additional keyword arguments for the underlying
                `.get_state_history()` method.

        Yields:
            Dict[str, Any]: The state history of the Agent.
        """
        if not self._runnable:
            self.set_up()
        for state_snapshot in self._runnable.get_state_history(config=config, **kwargs):
            yield state_snapshot._asdict()

    def get_state(
        self,
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Gets the current state of the Agent.

        Args:
            config (Optional[RunnableConfig]):
                Optional. The config for invoking the Agent.
            **kwargs:
                Optional. Additional keyword arguments for the underlying
                `.get_state()` method.

        Returns:
            Dict[str, Any]: The current state of the Agent.
        """
        if not self._runnable:
            self.set_up()
        return self._runnable.get_state(config=config, **kwargs)._asdict()

    def update_state(
        self,
        config: Optional["RunnableConfig"] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Updates the state of the Agent.

        Args:
            config (Optional[RunnableConfig]):
                Optional. The config for invoking the Agent.
            **kwargs:
                Optional. Additional keyword arguments for the underlying
                `.update_state()` method.

        Returns:
            Dict[str, Any]: The updated state of the Agent.
        """
        if not self._runnable:
            self.set_up()
        return self._runnable.update_state(config=config, **kwargs)

    def register_operations(self) -> Mapping[str, Sequence[str]]:
        """Registers the operations of the Agent.

        This mapping defines how different operation modes (e.g., "", "stream")
        are implemented by specific methods of the Agent. The "default" mode,
        represented by the empty string "", is associated with the `query` API,
        while the "stream" mode is associated with the `stream_query` API.

        Returns:
            Mapping[str, Sequence[str]]: A mapping of operation modes to a list
            of method names that implement those operation modes.
        """
        return {
            "": ["query", "get_state", "update_state"],
            "stream": ["stream_query", "get_state_history"],
        }
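
For orientation, a brief sketch of what the registry above implies for a deployed agent (names exactly as defined in this class):

```python
agent = LanggraphAgent(model="gemini-1.0-pro")
print(agent.register_operations())
# {'': ['query', 'get_state', 'update_state'],
#  'stream': ['stream_query', 'get_state_history']}
```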
@@ -0,0 +1,553 @@
# -*- coding: utf-8 -*-
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Union,
)

if TYPE_CHECKING:
    try:
        from llama_index.core.base.query_pipeline import query
        from llama_index.core.llms import function_calling
        from llama_index.core import query_pipeline

        FunctionCallingLLM = function_calling.FunctionCallingLLM
        QueryComponent = query.QUERY_COMPONENT_TYPE
        QueryPipeline = query_pipeline.QueryPipeline
    except ImportError:
        FunctionCallingLLM = Any
        QueryComponent = Any
        QueryPipeline = Any

    try:
        from opentelemetry.sdk import trace

        TracerProvider = trace.TracerProvider
        SpanProcessor = trace.SpanProcessor
        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
    except ImportError:
        TracerProvider = Any
        SpanProcessor = Any
        SynchronousMultiSpanProcessor = Any


def _default_model_builder(
    model_name: str,
    *,
    project: str,
    location: str,
    model_kwargs: Optional[Mapping[str, Any]] = None,
) -> "FunctionCallingLLM":
    """Creates a default model builder for LlamaIndex."""
    import vertexai
    from google.cloud.aiplatform import initializer
    from llama_index.llms import google_genai

    model_kwargs = model_kwargs or {}
    model = google_genai.GoogleGenAI(
        model=model_name,
        vertexai_config={"project": project, "location": location},
        **model_kwargs,
    )
    current_project = initializer.global_config.project
    current_location = initializer.global_config.location
    vertexai.init(project=current_project, location=current_location)
    return model


def _default_runnable_builder(
    model: "FunctionCallingLLM",
    *,
    system_instruction: Optional[str] = None,
    prompt: Optional["QueryComponent"] = None,
    retriever: Optional["QueryComponent"] = None,
    response_synthesizer: Optional["QueryComponent"] = None,
    runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "QueryPipeline":
    """Creates a default runnable builder for LlamaIndex."""
    try:
        from llama_index.core.query_pipeline import QueryPipeline
    except ImportError:
        raise ImportError(
            "Please call 'pip install google-cloud-aiplatform[llama_index]'."
        )

    prompt = prompt or _default_prompt(
        system_instruction=system_instruction,
    )
    runnable_kwargs = runnable_kwargs or {}  # Guard against None before unpacking.
    pipeline = QueryPipeline(**runnable_kwargs)
    pipeline_modules = {
        "prompt": prompt,
        "model": model,
    }
    if retriever:
        pipeline_modules["retriever"] = retriever
    if response_synthesizer:
        pipeline_modules["response_synthesizer"] = response_synthesizer

    pipeline.add_modules(pipeline_modules)
    pipeline.add_link("prompt", "model")
    if "retriever" in pipeline_modules:
        pipeline.add_link("model", "retriever")
    if "response_synthesizer" in pipeline_modules:
        pipeline.add_link("model", "response_synthesizer", dest_key="query_str")
        if "retriever" in pipeline_modules:
            pipeline.add_link("retriever", "response_synthesizer", dest_key="nodes")

    return pipeline


def _default_prompt(
    system_instruction: Optional[str] = None,
) -> "QueryComponent":
    """Creates a default prompt template for LlamaIndex.

    Handles both system instruction and user input.

    Args:
        system_instruction (str, optional): The system instruction to use.

    Returns:
        QueryComponent: The LlamaIndex QueryComponent.
    """
    try:
        from llama_index.core import prompts
        from llama_index.core.base.llms import types
    except ImportError:
        raise ImportError(
            "Please call 'pip install google-cloud-aiplatform[llama_index]'."
        )

    # Define a prompt template
    message_templates = []
    if system_instruction:
        message_templates.append(
            types.ChatMessage(role=types.MessageRole.SYSTEM, content=system_instruction)
        )
    # Add user input message
    message_templates.append(
        types.ChatMessage(role=types.MessageRole.USER, content="{input}")
    )

    # Create the prompt template
    return prompts.ChatPromptTemplate(message_templates=message_templates)


def _override_active_span_processor(
    tracer_provider: "TracerProvider",
    active_span_processor: "SynchronousMultiSpanProcessor",
):
    """Overrides the active span processor.

    When working with multiple LlamaIndexQueryPipelineAgents in the same
    environment, it's crucial to manage trace exports carefully.
    Each agent needs its own span processor tied to a unique project ID.
    While we add a new span processor for each agent, this can lead to
    unexpected behavior.
    For instance, with two agents linked to different projects, traces from the
    second agent might be sent to both projects.
    To prevent this and guarantee traces go to the correct project, we overwrite
    the active span processor whenever a new LlamaIndexQueryPipelineAgent is
    created.

    Args:
        tracer_provider (TracerProvider):
            The tracer provider to use for the project.
        active_span_processor (SynchronousMultiSpanProcessor):
            The span processor that overrides the tracer provider's
            active span processor.
    """
    if tracer_provider._active_span_processor:
        tracer_provider._active_span_processor.shutdown()
    tracer_provider._active_span_processor = active_span_processor


class LlamaIndexQueryPipelineAgent:
    """A LlamaIndex Query Pipeline Agent.

    This agent uses a LlamaIndex query pipeline, including prompt, model,
    retrieval and summarization steps. More details can be found in
    https://docs.llamaindex.ai/en/stable/module_guides/querying/pipeline/.
    """

    def __init__(
        self,
        model: str,
        *,
        system_instruction: Optional[str] = None,
        prompt: Optional["QueryComponent"] = None,
        model_kwargs: Optional[Mapping[str, Any]] = None,
        model_builder: Optional[Callable[..., "FunctionCallingLLM"]] = None,
        retriever_kwargs: Optional[Mapping[str, Any]] = None,
        retriever_builder: Optional[Callable[..., "QueryComponent"]] = None,
        response_synthesizer_kwargs: Optional[Mapping[str, Any]] = None,
        response_synthesizer_builder: Optional[Callable[..., "QueryComponent"]] = None,
        runnable_kwargs: Optional[Mapping[str, Any]] = None,
        runnable_builder: Optional[Callable[..., "QueryPipeline"]] = None,
        enable_tracing: bool = False,
    ):
        """Initializes the LlamaIndexQueryPipelineAgent.

        Under-the-hood, assuming .set_up() is called, this will correspond to
        ```python
        # model_builder
        model = model_builder(model_name, project, location, model_kwargs)

        # runnable_builder
        runnable = runnable_builder(
            prompt=prompt,
            model=model,
            retriever=retriever_builder(model, retriever_kwargs),
            response_synthesizer=response_synthesizer_builder(
                model, response_synthesizer_kwargs
            ),
            runnable_kwargs=runnable_kwargs,
        )
        ```

        When everything is based on their default values, this corresponds to a
        query pipeline `Prompt - Model`:
        ```python
        # Default Model Builder
        model = google_genai.GoogleGenAI(
            model=model_name,
            vertexai_config={
                "project": initializer.global_config.project,
                "location": initializer.global_config.location,
            },
        )

        # Default Prompt Builder
        prompt = prompts.ChatPromptTemplate(
            message_templates=[
                types.ChatMessage(
                    role=types.MessageRole.USER,
                    content="{input}",
                ),
            ],
        )

        # Default Runnable Builder
        runnable = QueryPipeline(
            modules={
                "prompt": prompt,
                "model": model,
            },
        )
        runnable.add_link("prompt", "model")
        ```

        When `system_instruction` is specified, the prompt will be updated to
        include the system instruction.
        ```python
        # Updated Prompt Builder
        prompt = prompts.ChatPromptTemplate(
            message_templates=[
                types.ChatMessage(
                    role=types.MessageRole.SYSTEM,
                    content=system_instruction,
                ),
                types.ChatMessage(
                    role=types.MessageRole.USER,
                    content="{input}",
                ),
            ],
        )
        ```

        When all inputs are specified, this corresponds to a query pipeline
        `Prompt - Model - Retriever - Summarizer`:
        ```python
        runnable = QueryPipeline(
            modules={
                "prompt": prompt,
                "model": model,
                "retriever": retriever_builder(retriever_kwargs),
                "response_synthesizer": response_synthesizer_builder(
                    response_synthesizer_kwargs
                ),
            },
        )
        runnable.add_link("prompt", "model")
        runnable.add_link("model", "retriever")
        runnable.add_link("model", "response_synthesizer", dest_key="query_str")
        runnable.add_link("retriever", "response_synthesizer", dest_key="nodes")
        ```

        Args:
            model (str):
                The name of the model (e.g. "gemini-1.0-pro").
            system_instruction (str):
                Optional. The system instruction to use for the agent.
            prompt (llama_index.core.base.query_pipeline.query.QUERY_COMPONENT_TYPE):
                Optional. The prompt template for the model.
            model_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the constructor of
                google_genai.GoogleGenAI. An example of a model_kwargs is:
                ```python
                {
                    # api_key (string): The API key for the GoogleGenAI model.
                    # The API key can also be fetched from the GOOGLE_API_KEY
                    # environment variable. If `vertexai_config` is provided,
                    # the API key is ignored.
                    "api_key": "your_api_key",
                    # temperature (float): Sampling temperature, it controls the
                    # degree of randomness in token selection. If not provided,
                    # the default temperature is 0.1.
                    "temperature": 0.1,
                    # context_window (int): The context window of the model.
                    # If not provided, the default context window is 200000.
                    "context_window": 200000,
                    # max_tokens (int): Token limit determines the maximum
                    # amount of text output from one prompt. If not provided,
                    # the default max_tokens is 256.
                    "max_tokens": 256,
                    # is_function_calling_model (bool): Whether the model is a
                    # function calling model. If not provided, the default
                    # is_function_calling_model is True.
                    "is_function_calling_model": True,
                }
                ```
            model_builder (Callable):
                Optional. Callable that returns a language model.
            retriever_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the retriever constructor.
            retriever_builder (Callable):
                Optional. Callable that returns a retriever object.
            response_synthesizer_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the response synthesizer
                constructor.
            response_synthesizer_builder (Callable):
                Optional. Callable that returns a response_synthesizer object.
            runnable_kwargs (Mapping[str, Any]):
                Optional. Keyword arguments for the runnable constructor.
            runnable_builder (Callable):
                Optional. Callable that returns a runnable (query pipeline).
            enable_tracing (bool):
                Optional. Whether to enable tracing. Defaults to False.
        """
        from google.cloud.aiplatform import initializer

        self._project = initializer.global_config.project
        self._location = initializer.global_config.location
        self._model_name = model
        self._system_instruction = system_instruction
        self._prompt = prompt

        self._model = None
        self._model_kwargs = model_kwargs or {}
        self._model_builder = model_builder

        self._retriever = None
        self._retriever_kwargs = retriever_kwargs or {}
        self._retriever_builder = retriever_builder

        self._response_synthesizer = None
        self._response_synthesizer_kwargs = response_synthesizer_kwargs or {}
        self._response_synthesizer_builder = response_synthesizer_builder

        self._runnable = None
        self._runnable_kwargs = runnable_kwargs or {}
        self._runnable_builder = runnable_builder

        self._instrumentor = None
        self._enable_tracing = enable_tracing

    def set_up(self):
        """Sets up the agent for execution of queries at runtime.

        It initializes the model and connects it with the prompt template,
        retriever and response_synthesizer.

        This method should not be called for an object that is being passed to
        the ReasoningEngine service for deployment, as it initializes clients
        that cannot be serialized.
        """
        if self._enable_tracing:
            from vertexai.reasoning_engines import _utils

            cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
            cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn()
            openinference_llama_index = (
                _utils._import_openinference_llama_index_or_warn()
            )
            opentelemetry = _utils._import_opentelemetry_or_warn()
            opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
            if all(
                (
                    cloud_trace_exporter,
                    cloud_trace_v2,
                    openinference_llama_index,
                    opentelemetry,
                    opentelemetry_sdk_trace,
                )
            ):
                import google.auth

                credentials, _ = google.auth.default()
                span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
                    project_id=self._project,
                    client=cloud_trace_v2.TraceServiceClient(
                        credentials=credentials.with_quota_project(self._project),
                    ),
                )
                span_processor: SpanProcessor = (
                    opentelemetry_sdk_trace.export.SimpleSpanProcessor(
                        span_exporter=span_exporter,
                    )
                )
                tracer_provider: TracerProvider = (
                    opentelemetry.trace.get_tracer_provider()
                )
                # Get the appropriate tracer provider:
                # 1. If _TRACER_PROVIDER is already set, use that.
                # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
                #    variable is set, use that.
                # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
                # If none of the above is set, we log a warning and
                # create a tracer provider.
                if not tracer_provider:
                    from google.cloud.aiplatform import base

                    _LOGGER = base.Logger(__name__)
                    _LOGGER.warning(
                        "No tracer provider. By default, "
                        "we should get one of the following providers: "
                        "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, "
                        "or _PROXY_TRACER_PROVIDER."
                    )
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids AttributeError:
                # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects have no
                # attribute 'add_span_processor'.
                if _utils.is_noop_or_proxy_tracer_provider(tracer_provider):
                    tracer_provider = opentelemetry_sdk_trace.TracerProvider()
                    opentelemetry.trace.set_tracer_provider(tracer_provider)
                # Avoids the OpenTelemetry "client already exists" error.
                _override_active_span_processor(
                    tracer_provider,
                    opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(),
                )
                tracer_provider.add_span_processor(span_processor)
                # When creating multiple LlamaIndexQueryPipelineAgents,
                # we need to keep the instrumentation up-to-date.
                # We deliberately override the instrumentation each time,
                # so that if different agents end up using different
                # instrumentations, we guarantee that the user is always
                # working with the most recent agent's instrumentation.
                self._instrumentor = openinference_llama_index.LlamaIndexInstrumentor()
                if self._instrumentor.is_instrumented_by_opentelemetry:
                    self._instrumentor.uninstrument()
                self._instrumentor.instrument()
            else:
                from google.cloud.aiplatform import base

                _LOGGER = base.Logger(__name__)
                _LOGGER.warning(
                    "enable_tracing=True but proceeding with tracing disabled "
                    "because not all packages for tracing have been installed"
                )

        model_builder = self._model_builder or _default_model_builder
        self._model = model_builder(
            model_name=self._model_name,
            model_kwargs=self._model_kwargs,
            project=self._project,
            location=self._location,
        )

        if self._retriever_builder:
            self._retriever = self._retriever_builder(
                model=self._model,
                retriever_kwargs=self._retriever_kwargs,
            )

        if self._response_synthesizer_builder:
            self._response_synthesizer = self._response_synthesizer_builder(
                model=self._model,
                response_synthesizer_kwargs=self._response_synthesizer_kwargs,
            )

        runnable_builder = self._runnable_builder or _default_runnable_builder
        self._runnable = runnable_builder(
            prompt=self._prompt,
            model=self._model,
            system_instruction=self._system_instruction,
            retriever=self._retriever,
            response_synthesizer=self._response_synthesizer,
            runnable_kwargs=self._runnable_kwargs,
        )

    def clone(self) -> "LlamaIndexQueryPipelineAgent":
        """Returns a clone of the LlamaIndexQueryPipelineAgent."""
        import copy

        return LlamaIndexQueryPipelineAgent(
            model=self._model_name,
            system_instruction=self._system_instruction,
            prompt=copy.deepcopy(self._prompt),
            model_kwargs=copy.deepcopy(self._model_kwargs),
            model_builder=self._model_builder,
            retriever_kwargs=copy.deepcopy(self._retriever_kwargs),
            retriever_builder=self._retriever_builder,
            response_synthesizer_kwargs=copy.deepcopy(
                self._response_synthesizer_kwargs
            ),
            response_synthesizer_builder=self._response_synthesizer_builder,
            runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
            runnable_builder=self._runnable_builder,
            enable_tracing=self._enable_tracing,
        )

    def query(
        self,
        input: Union[str, Mapping[str, Any]],
        **kwargs: Any,
    ) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]:
        """Queries the Agent with the given input.

        Args:
            input (Union[str, Mapping[str, Any]]):
                Required. The input to be passed to the Agent.
            **kwargs:
                Optional. Any additional keyword arguments to be passed to the
                `.run()` method of the corresponding QueryPipeline.

        Returns:
            The output of querying the Agent with the given input.
        """
        from vertexai.reasoning_engines import _utils

        if isinstance(input, str):
            input = {"input": input}

        if not self._runnable:
            self.set_up()

        if kwargs.get("batch"):
            nest_asyncio = _utils._import_nest_asyncio_or_warn()
            nest_asyncio.apply()

        return _utils.to_json_serializable_llama_index_object(
            self._runnable.run(**input, **kwargs)
        )
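
To close, a minimal usage sketch for this template (assuming `vertexai.init(...)` has been called and the `llama_index` extra is installed; the system instruction is an illustrative value):

```python
from vertexai.preview.reasoning_engines import LlamaIndexQueryPipelineAgent

agent = LlamaIndexQueryPipelineAgent(
    model="gemini-1.0-pro",
    system_instruction="Answer in one sentence.",
)
response = agent.query(input="What is a query pipeline?")
print(response)
```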