structure saas with tools
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,92 @@
from typing import Callable, Optional, Union

import httpx

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import CustomStreamingDecoder, ModelResponse

from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import _get_api_params
from .transformation import IBMWatsonXChatConfig

watsonx_chat_transformation = IBMWatsonXChatConfig()


class WatsonXChatHandler(OpenAILikeChatHandler):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def completion(
        self,
        *,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: Optional[str],
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params: dict = {},
        headers: Optional[dict] = None,
        logger_fn=None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        custom_endpoint: Optional[bool] = None,
        streaming_decoder: Optional[CustomStreamingDecoder] = None,
        fake_stream: bool = False,
    ):
        api_params = _get_api_params(params=optional_params)

        ## UPDATE HEADERS
        headers = watsonx_chat_transformation.validate_environment(
            headers=headers or {},
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
            litellm_params=litellm_params,
        )

        ## UPDATE PAYLOAD (optional params)
        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
            model=model,
            api_params=api_params,
        )
        optional_params.update(watsonx_auth_payload)

        ## GET API URL
        api_base = watsonx_chat_transformation.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params,
            stream=optional_params.get("stream", False),
        )

        return super().completion(
            model=model,
            messages=messages,
            api_base=api_base,
            custom_llm_provider=custom_llm_provider,
            custom_prompt_dict=custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging_obj,
            optional_params=optional_params,
            acompletion=acompletion,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
            timeout=timeout,
            client=client,
            custom_endpoint=True,
            streaming_decoder=streaming_decoder,
        )
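For orientation, a minimal usage sketch of the chat path added above (not part of this commit). It assumes litellm's `watsonx/` model prefix and environment-variable credentials; the model id and project id are placeholders.

import os

import litellm

os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_APIKEY"] = "<ibm-cloud-api-key>"       # placeholder
os.environ["WATSONX_PROJECT_ID"] = "<watsonx-project-id>"  # placeholder

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # placeholder model id
    messages=[{"role": "user", "content": "What is watsonx.ai?"}],
    max_tokens=200,
)
print(response.choices[0].message.content)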
@@ -0,0 +1,110 @@
"""
Translation from OpenAI's `/chat/completions` endpoint to IBM WatsonX's `/text/chat` endpoint.

Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
"""

from typing import List, Optional, Tuple, Union

from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import IBMWatsonXMixin


class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
    def get_supported_openai_params(self, model: str) -> List:
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
            "tools",
            "tool_choice",  # equivalent to tool_choice + tool_choice_options
            "logprobs",
            "top_logprobs",
            "n",
            "presence_penalty",
            "response_format",
        ]

    def is_tool_choice_option(self, tool_choice: Optional[Union[str, dict]]) -> bool:
        if tool_choice is None:
            return False
        if isinstance(tool_choice, str):
            return tool_choice in ["auto", "none", "required"]
        return False

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        ## TOOLS ##
        _tools = non_default_params.pop("tools", None)
        if _tools is not None:
            # remove 'additionalProperties' from tools
            _tools = _remove_additional_properties(_tools)
            # remove 'strict' from tools
            _tools = _remove_strict_from_schema(_tools)
        if _tools is not None:
            non_default_params["tools"] = _tools

        ## TOOL CHOICE ##

        _tool_choice = non_default_params.pop("tool_choice", None)
        if self.is_tool_choice_option(_tool_choice):
            optional_params["tool_choice_options"] = _tool_choice
        elif _tool_choice is not None:
            optional_params["tool_choice"] = _tool_choice
        return super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE")  # type: ignore
        dynamic_api_key = (
            api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
        )  # vllm does not require an api key
        return api_base, dynamic_api_key

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        if model.startswith("deployment/"):
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
                if stream
                else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
            )
            endpoint = endpoint.format(deployment_id=deployment_id)
        else:
            endpoint = (
                WatsonXAIEndpoint.CHAT_STREAM.value
                if stream
                else WatsonXAIEndpoint.CHAT.value
            )
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url
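A small sketch (not part of the commit) of how `map_openai_params` above splits `tool_choice`: string options such as "auto" are routed to `tool_choice_options`, while an explicit function selection stays under `tool_choice`. The model name is a placeholder.

config = IBMWatsonXChatConfig()

params = config.map_openai_params(
    non_default_params={"tool_choice": "auto", "max_tokens": 100},
    optional_params={},
    model="ibm/granite-13b-chat-v2",  # placeholder
    drop_params=False,
)
# params["tool_choice_options"] == "auto"

params = config.map_openai_params(
    non_default_params={
        "tool_choice": {"type": "function", "function": {"name": "get_weather"}}
    },
    optional_params={},
    model="ibm/granite-13b-chat-v2",  # placeholder
    drop_params=False,
)
# params["tool_choice"] == {"type": "function", "function": {"name": "get_weather"}}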
@@ -0,0 +1,292 @@
from typing import Dict, List, Optional, Union, cast

import httpx

import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials


class WatsonXAIError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[Dict, httpx.Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)


iam_token_cache = InMemoryCache()


def get_watsonx_iam_url():
    return (
        get_secret_str("WATSONX_IAM_URL") or "https://iam.cloud.ibm.com/identity/token"
    )


def generate_iam_token(api_key=None, **params) -> str:
    result: Optional[str] = iam_token_cache.get_cache(api_key)  # type: ignore

    if result is None:
        headers = {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        if api_key is None:
            api_key = (
                get_secret_str("WX_API_KEY")
                or get_secret_str("WATSONX_API_KEY")
                or get_secret_str("WATSONX_APIKEY")
            )
        if api_key is None:
            raise ValueError("API key is required")
        headers["Accept"] = "application/json"
        data = {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        }
        iam_token_url = get_watsonx_iam_url()
        verbose_logger.debug(
            "calling ibm `/identity/token` to retrieve IAM token.\nURL=%s\nheaders=%s\ndata=%s",
            iam_token_url,
            headers,
            data,
        )
        response = litellm.module_level_client.post(
            url=iam_token_url, data=data, headers=headers
        )
        response.raise_for_status()
        json_data = response.json()

        result = json_data["access_token"]
        iam_token_cache.set_cache(
            key=api_key,
            value=result,
            ttl=json_data["expires_in"] - 10,  # leave some buffer
        )

    return cast(str, result)


def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
    if token is not None:
        return token
    token = generate_iam_token(api_key)
    return token


def _get_api_params(
    params: dict,
) -> WatsonXAPIParams:
    """
    Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
    """
    # Load auth variables from params
    project_id = params.pop(
        "project_id", params.pop("watsonx_project", None)
    )  # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
    space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
    region_name = params.pop("region_name", params.pop("region", None))
    if region_name is None:
        region_name = params.pop(
            "watsonx_region_name", params.pop("watsonx_region", None)
        )  # consistent with how vertex ai + aws regions are accepted

    # Load auth variables from environment variables
    if project_id is None:
        project_id = (
            get_secret_str("WATSONX_PROJECT_ID")
            or get_secret_str("WX_PROJECT_ID")
            or get_secret_str("PROJECT_ID")
        )
    if region_name is None:
        region_name = (
            get_secret_str("WATSONX_REGION")
            or get_secret_str("WX_REGION")
            or get_secret_str("REGION")
        )
    if space_id is None:
        space_id = (
            get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
            or get_secret_str("WATSONX_SPACE_ID")
            or get_secret_str("WX_SPACE_ID")
            or get_secret_str("SPACE_ID")
        )

    if project_id is None:
        raise WatsonXAIError(
            status_code=401,
            message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
        )

    return WatsonXAPIParams(
        project_id=project_id,
        space_id=space_id,
        region_name=region_name,
    )


def convert_watsonx_messages_to_prompt(
    model: str,
    messages: List[AllMessageValues],
    provider: str,
    custom_prompt_dict: Dict,
) -> str:
    # handle anthropic prompts and amazon titan prompts
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_dict = custom_prompt_dict[model]
        prompt = ptf.custom_prompt(
            messages=messages,
            role_dict=model_prompt_dict.get(
                "role_dict", model_prompt_dict.get("roles")
            ),
            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
            bos_token=model_prompt_dict.get("bos_token", ""),
            eos_token=model_prompt_dict.get("eos_token", ""),
        )
        return prompt
    elif provider == "ibm-mistralai":
        prompt = ptf.mistral_instruct_pt(messages=messages)
    else:
        prompt: str = ptf.prompt_factory(  # type: ignore
            model=model, messages=messages, custom_llm_provider="watsonx"
        )
    return prompt


# Mixin class for shared IBM Watson X functionality
class IBMWatsonXMixin:
    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        default_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

        if "Authorization" in headers:
            return {**default_headers, **headers}
        token = cast(
            Optional[str],
            optional_params.get("token") or get_secret_str("WATSONX_TOKEN"),
        )
        if token:
            headers["Authorization"] = f"Bearer {token}"
        elif zen_api_key := get_secret_str("WATSONX_ZENAPIKEY"):
            headers["Authorization"] = f"ZenApiKey {zen_api_key}"
        else:
            token = _generate_watsonx_token(api_key=api_key, token=token)
            # build auth headers
            headers["Authorization"] = f"Bearer {token}"
        return {**default_headers, **headers}

    def _get_base_url(self, api_base: Optional[str]) -> str:
        url = (
            api_base
            or get_secret_str("WATSONX_API_BASE")  # consistent with 'AZURE_API_BASE'
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        if url is None:
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return url

    def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
        api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
        url = url + f"?version={api_version}"

        return url

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    @staticmethod
    def get_watsonx_credentials(
        optional_params: dict, api_key: Optional[str], api_base: Optional[str]
    ) -> WatsonXCredentials:
        api_key = (
            api_key
            or optional_params.pop("apikey", None)
            or get_secret_str("WATSONX_APIKEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WX_API_KEY")
        )

        api_base = (
            api_base
            or optional_params.pop(
                "url",
                optional_params.pop("api_base", optional_params.pop("base_url", None)),
            )
            or get_secret_str("WATSONX_API_BASE")
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        wx_credentials = optional_params.pop(
            "wx_credentials",
            optional_params.pop(
                "watsonx_credentials", None
            ),  # follow {provider}_credentials, same as vertex ai
        )

        token: Optional[str] = None

        if wx_credentials is not None:
            api_base = wx_credentials.get("url", api_base)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            token = wx_credentials.get(
                "token",
                wx_credentials.get(
                    "watsonx_token", None
                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            )
        if api_key is None or not isinstance(api_key, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
            )
        if api_base is None or not isinstance(api_base, str):
            raise WatsonXAIError(
                status_code=401,
                message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
            )
        return WatsonXCredentials(
            api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
        )

    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
        payload: dict = {}
        if model.startswith("deployment/"):
            if api_params["space_id"] is None:
                raise WatsonXAIError(
                    status_code=401,
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            payload["space_id"] = api_params["space_id"]
            return payload
        payload["model_id"] = model
        payload["project_id"] = api_params["project_id"]
        return payload
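For context, a standalone sketch (not part of this commit) of the IAM exchange that `generate_iam_token` performs against IBM Cloud; the API key is a placeholder, and `access_token`/`expires_in` are the response fields the caching logic above relies on.

import httpx

resp = httpx.post(
    "https://iam.cloud.ibm.com/identity/token",
    headers={
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    },
    data={
        "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
        "apikey": "<ibm-cloud-api-key>",  # placeholder
    },
)
resp.raise_for_status()
body = resp.json()
bearer_token = body["access_token"]  # becomes the 'Authorization: Bearer ...' header
cache_ttl = body["expires_in"] - 10  # same buffer used by generate_iam_token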
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
"""
Watsonx uses the llm_http_handler.py to handle the requests.
"""
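A simplified sketch (not `llm_http_handler.py` itself) of how a generic handler could drive the text-generation config hooks defined in this commit; the function name and argument values are illustrative only.

import httpx


def watsonx_text_generation(config, model, messages, optional_params, litellm_params, api_key=None):
    # resolve the URL first so 'api_version' is popped before the payload is built
    url = config.get_complete_url(
        api_base=None,
        api_key=api_key,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
        stream=optional_params.get("stream", False),
    )
    headers = config.validate_environment(
        headers={},
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        api_key=api_key,
    )
    data = config.transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )
    return httpx.post(url, json=data, headers=headers)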
@@ -0,0 +1,392 @@
import time
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Union,
)

import httpx

from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage
from litellm.utils import map_finish_reason

from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import (
    IBMWatsonXMixin,
    WatsonXAIError,
    _get_api_params,
    convert_watsonx_messages_to_prompt,
)

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
    """
    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)

    Supported params for all available watsonx.ai foundational models.

    - `decoding_method` (str): One of "greedy" or "sample"

    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.

    - `max_new_tokens` (integer): Maximum length of the generated tokens.

    - `min_new_tokens` (integer): Minimum length of the generated tokens.

    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".

    - `stop_sequences` (string[]): list of strings to use as stop sequences.

    - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.

    - `top_p` (float): top p for sampling - not available when decoding_method='greedy'.

    - `repetition_penalty` (float): token repetition penalty during text generation.

    - `truncate_input_tokens` (integer): Truncate input tokens to this length. Any more than this will be truncated.

    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.

    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.

    - `random_seed` (integer): Random seed for text generation.

    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.

    - `stream` (bool): If True, the model will return a stream of responses.
    """

    decoding_method: Optional[str] = "sample"
    temperature: Optional[float] = None
    max_new_tokens: Optional[int] = None  # litellm.max_tokens
    min_new_tokens: Optional[int] = None
    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    truncate_input_tokens: Optional[int] = None
    include_stop_sequences: Optional[bool] = False
    return_options: Optional[Dict[str, bool]] = None
    random_seed: Optional[int] = None  # e.g 42
    moderations: Optional[dict] = None
    stream: Optional[bool] = False

    def __init__(
        self,
        decoding_method: Optional[str] = None,
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        length_penalty: Optional[dict] = None,
        stop_sequences: Optional[List[str]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        truncate_input_tokens: Optional[int] = None,
        include_stop_sequences: Optional[bool] = None,
        return_options: Optional[dict] = None,
        random_seed: Optional[int] = None,
        moderations: Optional[dict] = None,
        stream: Optional[bool] = None,
        **kwargs,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def is_watsonx_text_param(self, param: str) -> bool:
        """
        Determine if user passed in a watsonx.ai text generation param
        """
        text_generation_params = [
            "decoding_method",
            "max_new_tokens",
            "min_new_tokens",
            "length_penalty",
            "stop_sequences",
            "top_k",
            "repetition_penalty",
            "truncate_input_tokens",
            "include_stop_sequences",
            "return_options",
            "random_seed",
            "moderations",
            "min_tokens",
        ]

        return param in text_generation_params

    def get_supported_openai_params(self, model: str):
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
        ]

    def map_openai_params(
        self,
        non_default_params: Dict,
        optional_params: Dict,
        model: str,
        drop_params: bool,
    ) -> Dict:
        extra_body = {}
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_new_tokens"] = v
            elif k == "stream":
                optional_params["stream"] = v
            elif k == "temperature":
                optional_params["temperature"] = v
            elif k == "top_p":
                optional_params["top_p"] = v
            elif k == "frequency_penalty":
                optional_params["repetition_penalty"] = v
            elif k == "seed":
                optional_params["random_seed"] = v
            elif k == "stop":
                optional_params["stop_sequences"] = v
            elif k == "decoding_method":
                extra_body["decoding_method"] = v
            elif k == "min_tokens":
                extra_body["min_new_tokens"] = v
            elif k == "top_k":
                extra_body["top_k"] = v
            elif k == "truncate_input_tokens":
                extra_body["truncate_input_tokens"] = v
            elif k == "length_penalty":
                extra_body["length_penalty"] = v
            elif k == "time_limit":
                extra_body["time_limit"] = v
            elif k == "return_options":
                extra_body["return_options"] = v

        if extra_body:
            optional_params["extra_body"] = extra_body
        return optional_params

    def get_mapped_special_auth_params(self) -> dict:
        """
        Common auth params across bedrock/vertex_ai/azure/watsonx
        """
        return {
            "project": "watsonx_project",
            "region_name": "watsonx_region_name",
            "token": "watsonx_token",
        }

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()

        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "eu-de",
            "eu-gb",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.ibm.com/docs/en/watsonx/saas?topic=integrations-regional-availability
        """
        return [
            "us-south",
        ]

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        headers: Dict,
    ) -> Dict:
        provider = model.split("/")[0]
        prompt = convert_watsonx_messages_to_prompt(
            model=model,
            messages=messages,
            provider=provider,
            custom_prompt_dict={},
        )
        extra_body_params = optional_params.pop("extra_body", {})
        optional_params.update(extra_body_params)
        watsonx_api_params = _get_api_params(params=optional_params)

        watsonx_auth_payload = self._prepare_payload(
            model=model,
            api_params=watsonx_api_params,
        )

        # init the payload to the text generation call
        payload = {
            "input": prompt,
            "moderations": optional_params.pop("moderations", {}),
            "parameters": optional_params,
            **watsonx_auth_payload,
        }

        return payload

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=raw_response.text,
        )

        json_resp = raw_response.json()

        if "results" not in json_resp:
            raise WatsonXAIError(
                status_code=500,
                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
            )
        if model_response is None:
            model_response = ModelResponse(model=json_resp.get("model_id", None))
        generated_text = json_resp["results"][0]["generated_text"]
        prompt_tokens = json_resp["results"][0]["input_token_count"]
        completion_tokens = json_resp["results"][0]["generated_token_count"]
        model_response.choices[0].message.content = generated_text  # type: ignore
        model_response.choices[0].finish_reason = map_finish_reason(
            json_resp["results"][0]["stop_reason"]
        )
        if json_resp.get("created_at"):
            model_response.created = int(
                datetime.fromisoformat(json_resp["created_at"]).timestamp()
            )
        else:
            model_response.created = int(time.time())
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        if model.startswith("deployment/"):
            # deployment models are passed in as 'deployment/<deployment_id>'
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
                if stream
                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
            )
            endpoint = endpoint.format(deployment_id=deployment_id)
        else:
            endpoint = (
                WatsonXAIEndpoint.TEXT_GENERATION_STREAM.value
                if stream
                else WatsonXAIEndpoint.TEXT_GENERATION.value
            )
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return WatsonxTextCompletionResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class WatsonxTextCompletionResponseIterator(BaseModelResponseIterator):
    # def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
    #     return self.chunk_parser(json.loads(str_line))

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            results = chunk.get("results", [])
            if len(results) > 0:
                text = results[0].get("generated_text", "")
                finish_reason = results[0].get("stop_reason")
                is_finished = finish_reason != "not_finished"

                return GenericStreamingChunk(
                    text=text,
                    is_finished=is_finished,
                    finish_reason=finish_reason,
                    usage=ChatCompletionUsageBlock(
                        prompt_tokens=results[0].get("input_token_count", 0),
                        completion_tokens=results[0].get("generated_token_count", 0),
                        total_tokens=results[0].get("input_token_count", 0)
                        + results[0].get("generated_token_count", 0),
                    ),
                )
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="stop",
                usage=None,
            )
        except Exception as e:
            raise e
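For reference, a hand-written example (not produced by running this code) of the request body that `transform_request` above assembles for a non-deployment model; the prompt, model id, and project id are placeholders.

example_payload = {
    "input": "<prompt built by convert_watsonx_messages_to_prompt>",
    "moderations": {},
    "parameters": {
        "max_new_tokens": 200,       # mapped from the OpenAI 'max_tokens' param
        "temperature": 0.7,
        "stop_sequences": ["\n\n"],  # mapped from the OpenAI 'stop' param
    },
    "model_id": "ibm/granite-13b-instruct-v2",  # placeholder
    "project_id": "<WATSONX_PROJECT_ID>",       # placeholder
}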
Binary file not shown.
@@ -0,0 +1,113 @@
"""
Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
"""

from typing import Optional

import httpx

from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import EmbeddingResponse, Usage

from ..common_utils import IBMWatsonXMixin, _get_api_params


class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
    def get_supported_openai_params(self, model: str) -> list:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return optional_params

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        watsonx_api_params = _get_api_params(params=optional_params)
        watsonx_auth_payload = self._prepare_payload(
            model=model,
            api_params=watsonx_api_params,
        )

        return {
            "inputs": input,
            "parameters": optional_params,
            **watsonx_auth_payload,
        }

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
        if model.startswith("deployment/"):
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = endpoint.format(deployment_id=deployment_id)
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        logging_obj.post_call(
            original_response=raw_response.text,
        )
        json_resp = raw_response.json()
        if model_response is None:
            model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
        results = json_resp.get("results", [])
        embedding_response = []
        for idx, result in enumerate(results):
            embedding_response.append(
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": result["embedding"],
                }
            )
        model_response.object = "list"
        model_response.data = embedding_response
        input_tokens = json_resp.get("input_token_count", 0)
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=input_tokens,
                completion_tokens=0,
                total_tokens=input_tokens,
            ),
        )
        return model_response
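Similarly, a hand-written example of the body that `transform_embedding_request` above builds; all values are placeholders.

example_embedding_request = {
    "inputs": ["first text to embed", "second text to embed"],
    "parameters": {},
    "model_id": "ibm/slate-30m-english-rtrvr",  # placeholder
    "project_id": "<WATSONX_PROJECT_ID>",       # placeholder
}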