structure saas with tools
@@ -0,0 +1,487 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import inspect
import json
import types
import typing
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union

import proto

from google.cloud.aiplatform import base
from google.api import httpbody_pb2
from google.protobuf import struct_pb2
from google.protobuf import json_format

try:
    # For LangChain templates, they might not import langchain_core and get
    # PydanticUserError: `query` is not fully defined; you should define
    # `RunnableConfig`, then call `query.model_rebuild()`.
    import langchain_core.runnables.config

    RunnableConfig = langchain_core.runnables.config.RunnableConfig
except ImportError:
    RunnableConfig = Any

try:
    from llama_index.core.base.response import schema as llama_index_schema
    from llama_index.core.base.llms import types as llama_index_types

    LlamaIndexResponse = llama_index_schema.Response
    LlamaIndexBaseModel = llama_index_schema.BaseModel
    LlamaIndexChatResponse = llama_index_types.ChatResponse
except ImportError:
    LlamaIndexResponse = Any
    LlamaIndexBaseModel = Any
    LlamaIndexChatResponse = Any

JsonDict = Dict[str, Any]

_LOGGER = base.Logger(__name__)


def to_proto(
    obj: Union[JsonDict, proto.Message],
    message: Optional[proto.Message] = None,
) -> proto.Message:
    """Parses a JSON-like object into a message.

    If the object is already a message, this will return the object as-is. If
    the object is a JSON dict, this will parse and merge the object into the
    message.

    Args:
        obj (Union[dict[str, Any], proto.Message]):
            Required. The object to convert to a proto message.
        message (proto.Message):
            Optional. A protocol buffer message to merge the obj into. It
            defaults to Struct() if unspecified.

    Returns:
        proto.Message: The message parsed from the object, or the object
        itself if it was already a message.
    """
    if message is None:
        message = struct_pb2.Struct()
    if isinstance(obj, (proto.Message, struct_pb2.Struct)):
        return obj
    try:
        # proto-plus messages wrap the raw protobuf message in `._pb`.
        json_format.ParseDict(obj, message._pb)
    except AttributeError:
        # Vanilla protobuf messages can be parsed into directly.
        json_format.ParseDict(obj, message)
    return message


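# Example usage (editor's sketch, not part of the original module): parsing a
# plain dict lands in a `struct_pb2.Struct` by default, since `message` is
# unspecified:
#
#     >>> msg = to_proto({"model": "gemini-pro", "temperature": 0.2})
#     >>> type(msg).__name__
#     'Struct'

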
def to_dict(message: proto.Message) -> JsonDict:
    """Converts the contents of the protobuf message to JSON format.

    Args:
        message (proto.Message):
            Required. The proto message to be converted to a JSON dictionary.

    Returns:
        dict[str, Any]: A dictionary containing the contents of the proto.
    """
    try:
        # Best effort attempt to convert the message into a JSON dictionary.
        result: JsonDict = json.loads(json_format.MessageToJson(message._pb))
    except AttributeError:
        result: JsonDict = json.loads(json_format.MessageToJson(message))
    return result


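# Example usage (editor's sketch): `to_dict` inverts `to_proto`, so a dict fed
# through both should come back unchanged (note that Struct stores every
# number as a double, so integers round-trip as floats):
#
#     >>> to_dict(to_proto({"tool": "search"}))
#     {'tool': 'search'}

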
def dataclass_to_dict(obj: dataclasses.dataclass) -> JsonDict:
    """Converts a dataclass to a JSON dictionary.

    Args:
        obj (dataclasses.dataclass):
            Required. The dataclass to be converted to a JSON dictionary.

    Returns:
        dict[str, Any]: A dictionary containing the contents of the dataclass.
    """
    return json.loads(json.dumps(dataclasses.asdict(obj)))


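# Example usage (editor's sketch, with a hypothetical dataclass):
#
#     >>> @dataclasses.dataclass
#     ... class Query:
#     ...     text: str
#     ...     top_k: int
#     >>> dataclass_to_dict(Query(text="hi", top_k=3))
#     {'text': 'hi', 'top_k': 3}

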
def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Dict[str, Any]:
    response = {}
    if hasattr(obj, "response"):
        response["response"] = obj.response
    if hasattr(obj, "source_nodes"):
        response["source_nodes"] = [node.model_dump_json() for node in obj.source_nodes]
    if hasattr(obj, "metadata"):
        response["metadata"] = obj.metadata

    return json.loads(json.dumps(response))


def _llama_index_chat_response_to_dict(
    obj: LlamaIndexChatResponse,
) -> Dict[str, Any]:
    return json.loads(obj.message.model_dump_json())


def _llama_index_base_model_to_dict(
    obj: LlamaIndexBaseModel,
) -> Dict[str, Any]:
    return json.loads(obj.model_dump_json())


def to_json_serializable_llama_index_object(
    obj: Union[
        LlamaIndexResponse,
        LlamaIndexBaseModel,
        LlamaIndexChatResponse,
        Sequence[LlamaIndexBaseModel],
    ]
) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]:
    """Converts a LlamaIndex object to a JSON serializable object."""
    if isinstance(obj, LlamaIndexResponse):
        return _llama_index_response_to_dict(obj)
    if isinstance(obj, LlamaIndexChatResponse):
        return _llama_index_chat_response_to_dict(obj)
    if isinstance(obj, Sequence):
        seq_result = []
        for item in obj:
            if isinstance(item, LlamaIndexBaseModel):
                seq_result.append(_llama_index_base_model_to_dict(item))
                continue
            seq_result.append(str(item))
        return seq_result
    if isinstance(obj, LlamaIndexBaseModel):
        return _llama_index_base_model_to_dict(obj)
    return str(obj)


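# Example usage (editor's sketch; it assumes the optional llama-index
# dependency is installed, otherwise the aliases above degrade to `Any` and
# none of the isinstance branches apply):
#
#     >>> response = llama_index_schema.Response(response="42")
#     >>> to_json_serializable_llama_index_object(response)
#     {'response': '42', 'source_nodes': [], 'metadata': None}

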
def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
    """Converts the contents of the httpbody message to JSON format.

    Args:
        body (httpbody_pb2.HttpBody):
            Required. The httpbody body to be converted to JSON.

    Yields:
        Any: Parsed JSON objects, the original body if it is not JSON, or
        None if the body is empty.
    """
    content_type = getattr(body, "content_type", None)
    data = getattr(body, "data", None)

    if content_type is None or data is None or "application/json" not in content_type:
        yield body
        return

    try:
        utf8_data = data.decode("utf-8")
    except Exception as e:
        _LOGGER.warning(f"Failed to decode data: {data}. Exception: {e}")
        yield body
        return

    if not utf8_data:
        yield None
        return

    # Handle the case of multiple dictionaries delimited by newlines.
    for line in utf8_data.split("\n"):
        if line:
            try:
                line = json.loads(line)
            except Exception as e:
                _LOGGER.warning(f"Failed to parse JSON: {line}. Exception: {e}")
            yield line


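# Example usage (editor's sketch): a streaming endpoint may return several
# JSON dictionaries delimited by newlines in a single HttpBody:
#
#     >>> body = httpbody_pb2.HttpBody(
#     ...     content_type="application/json",
#     ...     data=b'{"step": 1}\n{"step": 2}\n',
#     ... )
#     >>> list(yield_parsed_json(body))
#     [{'step': 1}, {'step': 2}]

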
def generate_schema(
    f: Callable[..., Any],
    *,
    schema_name: Optional[str] = None,
    descriptions: Mapping[str, str] = {},
    required: Sequence[str] = [],
) -> JsonDict:
    """Generates the OpenAPI Schema for a callable object.

    Only positional and keyword arguments of the function `f` will be supported
    in the OpenAPI Schema that is generated. I.e. `*args` and `**kwargs` will
    not be present in the OpenAPI schema returned from this function. For those
    cases, you can either include them in the docstring for `f`, or modify the
    OpenAPI schema returned from this function to include additional arguments.

    Args:
        f (Callable):
            Required. The function to generate an OpenAPI Schema for.
        schema_name (str):
            Optional. The name for the OpenAPI schema. If unspecified, the name
            of the Callable will be used.
        descriptions (Mapping[str, str]):
            Optional. A `{name: description}` mapping for annotating input
            arguments of the function with user-provided descriptions. It
            defaults to an empty dictionary (i.e. there will not be any
            description for any of the inputs).
        required (Sequence[str]):
            Optional. For the user to specify the set of required arguments in
            function calls to `f`. If unspecified, it will be automatically
            inferred from `f`.

    Returns:
        dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format.
    """
    pydantic = _import_pydantic_or_raise()
    defaults = dict(inspect.signature(f).parameters)
    fields_dict = {
        name: (
            # 1. We infer the argument type here: use Any rather than None so
            # it will not try to auto-infer the type based on the default value.
            (param.annotation if param.annotation != inspect.Parameter.empty else Any),
            pydantic.Field(
                # 2. We do not support default values for now.
                # default=(
                #     param.default if param.default != inspect.Parameter.empty
                #     else None
                # ),
                # 3. We support user-provided descriptions.
                description=descriptions.get(name, None),
            ),
        )
        for name, param in defaults.items()
        # We do not support *args or **kwargs.
        if param.kind
        in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
            inspect.Parameter.POSITIONAL_ONLY,
        )
    }
    parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
    # Postprocessing
    # 4. Suppress unnecessary title generation:
    #    * https://github.com/pydantic/pydantic/issues/1051
    #    * http://cl/586221780
    parameters.pop("title", "")
    for name, function_arg in parameters.get("properties", {}).items():
        function_arg.pop("title", "")
        annotation = defaults[name].annotation
        # 5. Nullable fields:
        #    * https://github.com/pydantic/pydantic/issues/1270
        #    * https://stackoverflow.com/a/58841311
        #    * https://github.com/pydantic/pydantic/discussions/4872
        if typing.get_origin(annotation) is Union and type(None) in typing.get_args(
            annotation
        ):
            # For "typing.Optional" arguments, function_arg might be a
            # dictionary like
            #
            # {'anyOf': [{'type': 'integer'}, {'type': 'null'}]}
            for schema in function_arg.pop("anyOf", []):
                schema_type = schema.get("type")
                if schema_type and schema_type != "null":
                    function_arg["type"] = schema_type
                    break
            function_arg["nullable"] = True
    # 6. Annotate required fields.
    if required:
        # We use the user-provided "required" fields if specified.
        parameters["required"] = required
    else:
        # Otherwise we infer it from the function signature.
        parameters["required"] = [
            k
            for k in defaults
            if (
                defaults[k].default == inspect.Parameter.empty
                and defaults[k].kind
                in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                    inspect.Parameter.POSITIONAL_ONLY,
                )
            )
        ]
    schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
    if schema_name:
        schema["name"] = schema_name
    return schema


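# Example usage (editor's sketch, with a hypothetical function; the exact
# schema layout can vary slightly across pydantic versions):
#
#     >>> def get_weather(city: str, unit: Optional[str] = None) -> str:
#     ...     """Gets the current weather for a city."""
#     >>> schema = generate_schema(get_weather)
#     >>> schema["name"], sorted(schema["parameters"]["properties"])
#     ('get_weather', ['city', 'unit'])
#     >>> schema["parameters"]["required"]
#     ['city']

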
def is_noop_or_proxy_tracer_provider(tracer_provider) -> bool:
    """Returns True if the tracer_provider is Proxy or NoOp."""
    opentelemetry = _import_opentelemetry_or_warn()
    ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider
    NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider
    return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider))


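# Example usage (editor's sketch; it requires the opentelemetry package):
# before a real tracer provider is installed, the global provider is a proxy,
# so this is expected to return True:
#
#     >>> from opentelemetry import trace
#     >>> is_noop_or_proxy_tracer_provider(trace.get_tracer_provider())
#     True

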
def _import_cloud_storage_or_raise() -> types.ModuleType:
    """Tries to import the Cloud Storage module."""
    try:
        from google.cloud import storage
    except ImportError as e:
        raise ImportError(
            "Cloud Storage is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        ) from e
    return storage


def _import_cloudpickle_or_raise() -> types.ModuleType:
    """Tries to import the cloudpickle module."""
    try:
        import cloudpickle  # noqa:F401
    except ImportError as e:
        raise ImportError(
            "cloudpickle is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        ) from e
    return cloudpickle


def _import_pydantic_or_raise() -> types.ModuleType:
    """Tries to import the pydantic module."""
    try:
        import pydantic

        _ = pydantic.Field
    except AttributeError:
        # Fall back to the pydantic.v1 compatibility layer when `Field` is
        # not available at the top level.
        from pydantic import v1 as pydantic
    except ImportError as e:
        raise ImportError(
            "pydantic is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        ) from e
    return pydantic


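# Example usage (editor's sketch): these guarded-import helpers let callers
# fail or warn lazily, only when the optional dependency is actually needed:
#
#     >>> pydantic = _import_pydantic_or_raise()
#     >>> field_info = pydantic.Field(description="a user-provided description")

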
def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the opentelemetry module."""
    try:
        import opentelemetry  # noqa:F401

        return opentelemetry
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
    return None


def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the opentelemetry.sdk.trace module."""
    try:
        import opentelemetry.sdk.trace  # noqa:F401

        return opentelemetry.sdk.trace
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
    return None


def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the google.cloud.trace_v2 module."""
    try:
        import google.cloud.trace_v2

        return google.cloud.trace_v2
    except ImportError:
        _LOGGER.warning(
            "google-cloud-trace is not installed. Please call "
            "'pip install google-cloud-aiplatform[reasoningengine]'."
        )
    return None


def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the opentelemetry.exporter.cloud_trace module."""
    try:
        import opentelemetry.exporter.cloud_trace  # noqa:F401

        return opentelemetry.exporter.cloud_trace
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-exporter-gcp-trace is not installed. Please "
            "call 'pip install google-cloud-aiplatform[langchain]'."
        )
    return None


def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the openinference.instrumentation.langchain module."""
    try:
        import openinference.instrumentation.langchain  # noqa:F401

        return openinference.instrumentation.langchain
    except ImportError:
        _LOGGER.warning(
            "openinference-instrumentation-langchain is not installed. Please "
            "call 'pip install google-cloud-aiplatform[langchain]'."
        )
    return None


def _import_openinference_autogen_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the openinference.instrumentation.autogen module."""
    try:
        import openinference.instrumentation.autogen  # noqa:F401

        return openinference.instrumentation.autogen
    except ImportError:
        _LOGGER.warning(
            "openinference-instrumentation-autogen is not installed. Please "
            "call 'pip install openinference-instrumentation-autogen'."
        )
    return None


def _import_openinference_llama_index_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the openinference.instrumentation.llama_index module."""
    try:
        import openinference.instrumentation.llama_index  # noqa:F401

        return openinference.instrumentation.llama_index
    except ImportError:
        _LOGGER.warning(
            "openinference-instrumentation-llama_index is not installed. Please "
            "call 'pip install google-cloud-aiplatform[llama_index]'."
        )
    return None


def _import_autogen_tools_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the autogen.tools module."""
    try:
        from autogen import tools

        return tools
    except ImportError:
        _LOGGER.warning(
            "autogen.tools is not installed. Please call: `pip install ag2[tools]`"
        )
    return None


def _import_nest_asyncio_or_warn() -> Optional[types.ModuleType]:
    """Tries to import the nest_asyncio module."""
    try:
        import nest_asyncio

        return nest_asyncio
    except ImportError:
        _LOGGER.warning(
            "nest_asyncio is not installed. Please call: `pip install nest-asyncio`"
        )
    return None