structure saas with tools
@@ -0,0 +1,92 @@
from typing import Callable, Optional, Union

import httpx

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import CustomStreamingDecoder, ModelResponse

from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import _get_api_params
from .transformation import IBMWatsonXChatConfig

watsonx_chat_transformation = IBMWatsonXChatConfig()


class WatsonXChatHandler(OpenAILikeChatHandler):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def completion(
        self,
        *,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: Optional[str],
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params: dict = {},
        headers: Optional[dict] = None,
        logger_fn=None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
        streaming_decoder: Optional[CustomStreamingDecoder] = None,
        fake_stream: bool = False,
    ):
        api_params = _get_api_params(params=optional_params)

        ## UPDATE HEADERS
        headers = watsonx_chat_transformation.validate_environment(
            headers=headers or {},
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
            litellm_params=litellm_params,
        )

        ## UPDATE PAYLOAD (optional params)
        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
            model=model,
            api_params=api_params,
        )
        optional_params.update(watsonx_auth_payload)

        ## GET API URL
        api_base = watsonx_chat_transformation.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
            stream=optional_params.get("stream", False),
        )

        return super().completion(
            model=model,
            messages=messages,
            api_base=api_base,
            custom_llm_provider=custom_llm_provider,
            custom_prompt_dict=custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging_obj,
            optional_params=optional_params,
            acompletion=acompletion,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
            timeout=timeout,
            client=client,
            custom_endpoint=True,
            streaming_decoder=streaming_decoder,
        )
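
For orientation, a minimal usage sketch (not part of the commit, presumably litellm/llms/watsonx/chat/handler.py by its relative imports): callers do not construct WatsonXChatHandler directly; litellm routes models with a `watsonx/` prefix to this handler's completion method. The model ID and environment variable names below follow the litellm watsonx docs but are illustrative; adjust to your deployment.

    # Usage sketch: reach WatsonXChatHandler.completion via litellm.completion.
    import os
    import litellm

    os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # your region endpoint
    os.environ["WATSONX_APIKEY"] = "..."        # IBM Cloud API key (placeholder)
    os.environ["WATSONX_PROJECT_ID"] = "..."    # watsonx.ai project ID (placeholder)

    response = litellm.completion(
        model="watsonx/ibm/granite-3-8b-instruct",  # illustrative model ID
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)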
@@ -0,0 +1,110 @@
"""
Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.

Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
"""

from typing import List, Optional, Tuple, Union

from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import IBMWatsonXMixin


class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
    def get_supported_openai_params(self, model: str) -> List:
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
            "tools",
            "tool_choice",  # equivalent to tool_choice + tool_choice_options
            "logprobs",
            "top_logprobs",
            "n",
            "presence_penalty",
            "response_format",
        ]

    def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
        if tool_choice is None:
            return False
        if isinstance(tool_choice, str):
            return tool_choice in ["auto", "none", "required"]
        return False

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        ## TOOLS ##
        _tools = non_default_params.pop("tools", None)
        if _tools is not None:
            # remove 'additionalProperties' from tools
            _tools = _remove_additional_properties(_tools)
            # remove 'strict' from tools
            _tools = _remove_strict_from_schema(_tools)
        if _tools is not None:
            non_default_params["tools"] = _tools

        ## TOOL CHOICE ##
        _tool_choice = non_default_params.pop("tool_choice", None)
        if self.is_tool_choice_option(_tool_choice):
            optional_params["tool_choice_options"] = _tool_choice
        elif _tool_choice is not None:
            optional_params["tool_choice"] = _tool_choice
        return super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")  # type: ignore
        dynamic_api_key = (
            api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
        )  # vllm does not require an api key
        return api_base, dynamic_api_key

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        if model.startswith("deployment/"):
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
                if stream
                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
            )
            endpoint = endpoint.format(deployment_id=deployment_id)
        else:
            endpoint = (
                WatsonXAIEndpoint.CHAT_STREAM.value
                if stream
                else WatsonXAIEndpoint.CHAT.value
            )
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url
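
To make the transformation concrete, a short sketch (not part of the commit) of how IBMWatsonXChatConfig behaves. It assumes the WatsonXAIEndpoint enum values follow the watsonx.ai REST paths (e.g. /ml/v1/text/chat and /ml/v1/deployments/{deployment_id}/text/chat) and that _get_base_url returns the api_base it is given; the base URL, model IDs, and version date are illustrative.

    config = IBMWatsonXChatConfig()

    # String tool choices ("auto" / "none" / "required") become watsonx's
    # `tool_choice_options`; a dict forcing a specific function would pass
    # through unchanged as `tool_choice`.
    optional_params: dict = {}
    config.map_openai_params(
        non_default_params={"tool_choice": "auto"},
        optional_params=optional_params,
        model="ibm/granite-3-8b-instruct",  # illustrative model ID
        drop_params=False,
    )
    assert optional_params.get("tool_choice_options") == "auto"

    # `deployment/<id>` models route to the per-deployment chat endpoint;
    # everything else goes to the shared /text/chat endpoint.
    url = config.get_complete_url(
        api_base="https://us-south.ml.cloud.ibm.com",  # illustrative region URL
        api_key=None,
        model="deployment/my-deployment-id",
        optional_params={"api_version": "2024-03-14"},  # illustrative version date
        litellm_params={},
        stream=False,
    )
    # Expected shape, assuming the enum paths above:
    # https://us-south.ml.cloud.ibm.com/ml/v1/deployments/my-deployment-id/text/chat?version=2024-03-14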