structure saas with tools
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,48 @@
"""
Support for GPT-4o audio Family

OpenAI Doc: https://platform.openai.com/docs/guides/audio/quickstart?audio-generation-quickstart-example=audio-in&lang=python
"""

import litellm

from .gpt_transformation import OpenAIGPTConfig


class OpenAIGPTAudioConfig(OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/guides/audio
    """

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the `gpt-audio` models
        """

        all_openai_params = super().get_supported_openai_params(model=model)
        audio_specific_params = ["audio"]
        return all_openai_params + audio_specific_params

    def is_model_gpt_audio_model(self, model: str) -> bool:
        if model in litellm.open_ai_chat_completion_models and "audio" in model:
            return True
        return False

    def _map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return super()._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
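A minimal usage sketch for the audio config above; the model name and the direct instantiation are illustrative, not part of this commit, and the results depend on which models litellm has registered.

# Illustrative only - assumes the class defined above is importable in your environment.
config = OpenAIGPTAudioConfig()

model = "gpt-4o-audio-preview"  # hypothetical model name, for illustration
print(config.is_model_gpt_audio_model(model))  # True only if litellm registers it as an OpenAI chat model
print("audio" in config.get_supported_openai_params(model))  # the audio-specific param is appended to the base list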
@@ -0,0 +1,419 @@
"""
Support for gpt model family
"""

from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Union,
    cast,
)

import httpx

import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionFileObject,
    ChatCompletionFileObjectFile,
    ChatCompletionImageObject,
    ChatCompletionImageUrlObject,
)
from litellm.types.utils import ModelResponse, ModelResponseStream
from litellm.utils import convert_to_model_response_object

from ..common_utils import OpenAIError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIGPTConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        base_params = [
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "top_logprobs",
            "max_tokens",
            "max_completion_tokens",
            "modalities",
            "prediction",
            "n",
            "presence_penalty",
            "seed",
            "stop",
            "stream",
            "stream_options",
            "temperature",
            "top_p",
            "tools",
            "tool_choice",
            "function_call",
            "functions",
            "max_retries",
            "extra_headers",
            "parallel_tool_calls",
            "audio",
        ]  # works across all models

        model_specific_params = []
        if (
            model != "gpt-3.5-turbo-16k" and model != "gpt-4"
        ):  # gpt-4 does not support 'response_format'
            model_specific_params.append("response_format")

        if (
            model in litellm.open_ai_chat_completion_models
        ) or model in litellm.open_ai_text_completion_models:
            model_specific_params.append(
                "user"
            )  # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
        return base_params + model_specific_params

    def _map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in the API call.

        Args:
            non_default_params (dict): Non-default parameters to filter.
            optional_params (dict): Optional parameters to update.
            model (str): Model name for parameter support check.

        Returns:
            dict: Updated optional_params with supported non-default parameters.
        """
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return self._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """OpenAI no longer supports image_url as a string, so we need to convert it to a dict"""
        for message in messages:
            message_content = message.get("content")
            if message_content and isinstance(message_content, list):
                for content_item in message_content:
                    litellm_specific_params = {"format"}
                    if content_item.get("type") == "image_url":
                        content_item = cast(ChatCompletionImageObject, content_item)
                        if isinstance(content_item["image_url"], str):
                            content_item["image_url"] = {
                                "url": content_item["image_url"],
                            }
                        elif isinstance(content_item["image_url"], dict):
                            new_image_url_obj = ChatCompletionImageUrlObject(
                                **{  # type: ignore
                                    k: v
                                    for k, v in content_item["image_url"].items()
                                    if k not in litellm_specific_params
                                }
                            )
                            content_item["image_url"] = new_image_url_obj
                    elif content_item.get("type") == "file":
                        content_item = cast(ChatCompletionFileObject, content_item)
                        file_obj = content_item["file"]
                        new_file_obj = ChatCompletionFileObjectFile(
                            **{  # type: ignore
                                k: v
                                for k, v in file_obj.items()
                                if k not in litellm_specific_params
                            }
                        )
                        content_item["file"] = new_file_obj
        return messages

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the overall request to be sent to the API.

        Returns:
            dict: The transformed request. Sent as the body of the API call.
        """
        messages = self._transform_messages(messages=messages, model=model)
        return {
            "model": model,
            "messages": messages,
            **optional_params,
        }

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the response from the API.

        Returns:
            dict: The transformed response.
        """

        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = raw_response.json()
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise OpenAIError(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )
        raw_response_headers = dict(raw_response.headers)
        final_response_obj = convert_to_model_response_object(
            response_object=completion_response,
            model_response_object=model_response,
            hidden_params={"headers": raw_response_headers},
            _response_headers=raw_response_headers,
        )

        return cast(ModelResponse, final_response_obj)

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return OpenAIError(
            status_code=status_code,
            message=error_message,
            headers=cast(httpx.Headers, headers),
        )

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.

        Returns:
            str: The complete URL for the API call.
        """
        if api_base is None:
            api_base = "https://api.openai.com"
        endpoint = "chat/completions"

        # Remove trailing slash from api_base if present
        api_base = api_base.rstrip("/")

        # Check if endpoint is already in the api_base
        if endpoint in api_base:
            return api_base

        return f"{api_base}/{endpoint}"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        # Ensure Content-Type is set to application/json
        if "content-type" not in headers and "Content-Type" not in headers:
            headers["Content-Type"] = "application/json"

        return headers

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Calls OpenAI's `/v1/models` endpoint and returns the list of models.
        """

        if api_base is None:
            api_base = "https://api.openai.com"
        if api_key is None:
            api_key = get_secret_str("OPENAI_API_KEY")

        response = litellm.module_level_client.get(
            url=f"{api_base}/v1/models",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        if response.status_code != 200:
            raise Exception(f"Failed to get models: {response.text}")

        models = response.json()["data"]
        return [model["id"] for model in models]

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        return (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )

    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        return (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )

    @staticmethod
    def get_base_model(model: Optional[str] = None) -> Optional[str]:
        return model

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return OpenAIChatCompletionStreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        try:
            return ModelResponseStream(
                id=chunk["id"],
                object="chat.completion.chunk",
                created=chunk["created"],
                model=chunk["model"],
                choices=chunk["choices"],
            )
        except Exception as e:
            raise e
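A rough sketch of the parameter-filtering behaviour implemented by `map_openai_params` above: only params that the config reports as supported are copied into `optional_params`, everything else is left out of the request body. The model name and values are illustrative.

# Illustrative only - mirrors the filtering loop in OpenAIGPTConfig._map_openai_params.
config = OpenAIGPTConfig()

non_default_params = {"temperature": 0.2, "made_up_param": 123}  # second key is deliberately unsupported
optional_params = config.map_openai_params(
    non_default_params=non_default_params,
    optional_params={},
    model="gpt-4o",
    drop_params=False,
)
# "temperature" is in get_supported_openai_params(), so it is forwarded;
# "made_up_param" is silently omitted from the request body.
assert optional_params == {"temperature": 0.2}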
@@ -0,0 +1,3 @@
"""
LLM Calling done in `openai/openai.py`
"""
@@ -0,0 +1,159 @@
"""
Support for o1/o3 model family

https://platform.openai.com/docs/guides/reasoning

Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
"""

from typing import List, Optional

import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.utils import (
    supports_function_calling,
    supports_parallel_function_calling,
    supports_response_schema,
    supports_system_messages,
)

from .gpt_transformation import OpenAIGPTConfig


class OpenAIOSeriesConfig(OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/guides/reasoning
    """

    @classmethod
    def get_config(cls):
        return super().get_config()

    def translate_developer_role_to_system_role(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """
        O-series models support `developer` role.
        """
        return messages

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the given model
        """

        all_openai_params = super().get_supported_openai_params(model=model)
        non_supported_params = [
            "logprobs",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "top_logprobs",
        ]

        o_series_only_param = ["reasoning_effort"]

        all_openai_params.extend(o_series_only_param)

        try:
            model, custom_llm_provider, api_base, api_key = get_llm_provider(
                model=model
            )
        except Exception:
            verbose_logger.debug(
                f"Unable to infer model provider for model={model}, defaulting to openai for o1 supported param check"
            )
            custom_llm_provider = "openai"

        _supports_function_calling = supports_function_calling(
            model, custom_llm_provider
        )
        _supports_response_schema = supports_response_schema(model, custom_llm_provider)
        _supports_parallel_tool_calls = supports_parallel_function_calling(
            model, custom_llm_provider
        )

        if not _supports_function_calling:
            non_supported_params.append("tools")
            non_supported_params.append("tool_choice")
            non_supported_params.append("function_call")
            non_supported_params.append("functions")

        if not _supports_parallel_tool_calls:
            non_supported_params.append("parallel_tool_calls")

        if not _supports_response_schema:
            non_supported_params.append("response_format")

        return [
            param for param in all_openai_params if param not in non_supported_params
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        if "max_tokens" in non_default_params:
            optional_params["max_completion_tokens"] = non_default_params.pop(
                "max_tokens"
            )
        if "temperature" in non_default_params:
            temperature_value: Optional[float] = non_default_params.pop("temperature")
            if temperature_value is not None:
                if temperature_value == 1:
                    optional_params["temperature"] = temperature_value
                else:
                    ## UNSUPPORTED TEMPERATURE VALUE
                    if litellm.drop_params is True or drop_params is True:
                        pass
                    else:
                        raise litellm.utils.UnsupportedParamsError(
                            message="O-series models don't support temperature={}. Only temperature=1 is supported. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
                                temperature_value
                            ),
                            status_code=400,
                        )

        return super()._map_openai_params(
            non_default_params, optional_params, model, drop_params
        )

    def is_model_o_series_model(self, model: str) -> bool:
        if model in litellm.open_ai_chat_completion_models and (
            "o1" in model
            or "o3" in model
            or "o4"
            in model  # [TODO] make this a more generic check (e.g. using `openai-o-series` as provider like gemini)
        ):
            return True
        return False

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """
        Handles limitations of the o1 model family.
        - modalities: image => drop param (if user opts in to dropping param)
        - role: system ==> translate to role 'user'
        """
        _supports_system_messages = supports_system_messages(model, "openai")
        for i, message in enumerate(messages):
            if message["role"] == "system" and not _supports_system_messages:
                new_message = ChatCompletionUserMessage(
                    content=message["content"], role="user"
                )
                messages[i] = new_message  # Replace the old message with the new one

        messages = super()._transform_messages(messages, model)
        return messages
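A short sketch of the temperature handling in `map_openai_params` above: `max_tokens` is renamed to `max_completion_tokens`, temperature=1 passes through, and any other temperature either raises or is silently dropped depending on `drop_params`. The model name is illustrative.

# Illustrative only - exercises OpenAIOSeriesConfig.map_openai_params defined above.
config = OpenAIOSeriesConfig()

# max_tokens is translated and temperature=1 is allowed through unchanged.
params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 1},
    optional_params={},
    model="o3-mini",
    drop_params=False,
)
# expected: {"max_completion_tokens": 256, "temperature": 1}

# temperature=0.2 with drop_params=True is dropped instead of raising
# litellm.utils.UnsupportedParamsError.
params = config.map_openai_params(
    non_default_params={"temperature": 0.2},
    optional_params={},
    model="o3-mini",
    drop_params=True,
)
# expected: {}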
@@ -0,0 +1,208 @@
"""
Common helpers / utils across all OpenAI endpoints
"""

import hashlib
import json
from typing import Any, Dict, List, Literal, Optional, Union

import httpx
import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS


class OpenAIError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        if request:
            self.request = request
        else:
            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            status_code=status_code,
            message=self.message,
            headers=self.headers,
            request=self.request,
            response=self.response,
            body=body,
        )


####### Error Handling Utils for OpenAI API #######################
###################################################################
def drop_params_from_unprocessable_entity_error(
    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
    data: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message.

    Args:
        e (UnprocessableEntityError): The UnprocessableEntityError exception
        data (Dict[str, Any]): The original data dictionary containing all parameters

    Returns:
        Dict[str, Any]: A new dictionary with invalid parameters removed
    """
    invalid_params: List[str] = []
    if isinstance(e, httpx.HTTPStatusError):
        error_json = e.response.json()
        error_message = error_json.get("error", {})
        error_body = error_message
    else:
        error_body = e.body
    if (
        error_body is not None
        and isinstance(error_body, dict)
        and error_body.get("message")
    ):
        message = error_body.get("message", {})
        if isinstance(message, str):
            try:
                message = json.loads(message)
            except json.JSONDecodeError:
                message = {"detail": message}
        detail = message.get("detail")

        if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
            for error_dict in detail:
                if (
                    error_dict.get("loc")
                    and isinstance(error_dict.get("loc"), list)
                    and len(error_dict.get("loc")) == 2
                ):
                    invalid_params.append(error_dict["loc"][1])

    new_data = {k: v for k, v in data.items() if k not in invalid_params}

    return new_data


class BaseOpenAILLM:
    """
    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
    """

    @staticmethod
    def get_cached_openai_client(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
        return _cached_client

    @staticmethod
    def set_cached_openai_client(
        openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
        client_type: Literal["openai", "azure"],
        client_initialization_params: dict,
    ):
        """Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS seconds"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        litellm.in_memory_llm_clients_cache.set_cache(
            key=_cache_key,
            value=openai_client,
            ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
        )

    @staticmethod
    def get_openai_client_cache_key(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> str:
        """Creates a cache key for the OpenAI client based on the client initialization parameters"""
        hashed_api_key = None
        if client_initialization_params.get("api_key") is not None:
            hash_object = hashlib.sha256(
                client_initialization_params.get("api_key", "").encode()
            )
            # Hexadecimal representation of the hash
            hashed_api_key = hash_object.hexdigest()

        # Create a more readable cache key using a list of key-value pairs
        key_parts = [
            f"hashed_api_key={hashed_api_key}",
            f"is_async={client_initialization_params.get('is_async')}",
        ]

        LITELLM_CLIENT_SPECIFIC_PARAMS = [
            "timeout",
            "max_retries",
            "organization",
            "api_base",
        ]
        openai_client_fields = (
            BaseOpenAILLM.get_openai_client_initialization_param_fields(
                client_type=client_type
            )
            + LITELLM_CLIENT_SPECIFIC_PARAMS
        )

        for param in openai_client_fields:
            key_parts.append(f"{param}={client_initialization_params.get(param)}")

        _cache_key = ",".join(key_parts)
        return _cache_key

    @staticmethod
    def get_openai_client_initialization_param_fields(
        client_type: Literal["openai", "azure"]
    ) -> List[str]:
        """Returns a list of fields that are used to initialize the OpenAI client"""
        import inspect

        from openai import AzureOpenAI, OpenAI

        if client_type == "openai":
            signature = inspect.signature(OpenAI.__init__)
        else:
            signature = inspect.signature(AzureOpenAI.__init__)

        # Extract parameter names, excluding 'self'
        param_names = [param for param in signature.parameters if param != "self"]
        return param_names

    @staticmethod
    def _get_async_http_client() -> Optional[httpx.AsyncClient]:
        if litellm.aclient_session is not None:
            return litellm.aclient_session

        return httpx.AsyncClient(
            limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
            verify=litellm.ssl_verify,
        )

    @staticmethod
    def _get_sync_http_client() -> Optional[httpx.Client]:
        if litellm.client_session is not None:
            return litellm.client_session
        return httpx.Client(
            limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
            verify=litellm.ssl_verify,
        )
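A minimal sketch of the error-driven param dropping above, using a hand-built `httpx.HTTPStatusError` whose body follows the `{"detail": [{"loc": [...]}]}` shape the helper looks for; the payload and URL are fabricated for illustration.

# Illustrative only - builds a fake 422 response in the shape the helper parses.
import httpx

request = httpx.Request("POST", "https://example.invalid/v1/chat/completions")
response = httpx.Response(
    status_code=422,
    request=request,
    json={
        "error": {
            "message": '{"detail": [{"loc": ["body", "response_format"], "msg": "unsupported"}]}'
        }
    },
)
err = httpx.HTTPStatusError("422", request=request, response=response)

data = {"model": "my-model", "response_format": {"type": "json_object"}, "temperature": 0.1}
cleaned = drop_params_from_unprocessable_entity_error(e=err, data=data)
# "response_format" was named in loc[1], so it is removed; everything else survives.
assert "response_format" not in cleaned and "temperature" in cleaned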
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,318 @@
import json
from typing import Callable, List, Optional, Union

from openai import AsyncOpenAI, OpenAI

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
from litellm.utils import ProviderConfigManager

from ..common_utils import OpenAIError
from .transformation import OpenAITextCompletionConfig


class OpenAITextCompletion(BaseLLM):
    openai_text_completion_global_config = OpenAITextCompletionConfig()

    def __init__(self) -> None:
        super().__init__()

    def validate_environment(self, api_key):
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def completion(
        self,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        litellm_params=None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            # don't send max retries to the api, if set

            provider_config = ProviderConfigManager.get_provider_text_completion_config(
                model=model,
                provider=LlmProviders(custom_llm_provider),
            )

            data = provider_config.transform_text_completion_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                headers=headers,
            )
            max_retries = data.pop("max_retries", 2)
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        api_key=api_key,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        model=model,
                        timeout=timeout,
                        max_retries=max_retries,
                        client=client,
                        organization=organization,
                    )
                else:
                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    api_key=api_key,
                    data=data,
                    headers=headers,
                    model_response=model_response,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,  # type: ignore
                    client=client,
                    organization=organization,
                )
            else:
                if client is None:
                    openai_client = OpenAI(
                        api_key=api_key,
                        base_url=api_base,
                        http_client=litellm.client_session,
                        timeout=timeout,
                        max_retries=max_retries,  # type: ignore
                        organization=organization,
                    )
                else:
                    openai_client = client

                raw_response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
                response = raw_response.parse()
                response_json = response.model_dump()

                ## LOGGING
                logging_obj.post_call(
                    api_key=api_key,
                    original_response=response_json,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return TextCompletionResponse(**response_json)
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    async def acompletion(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        timeout: float,
        max_retries: int,
        organization: Optional[str] = None,
        client=None,
    ):
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                openai_aclient = client

            raw_response = await openai_aclient.completions.with_raw_response.create(
                **data
            )
            response = raw_response.parse()
            response_json = response.model_dump()

            ## LOGGING
            logging_obj.post_call(
                api_key=api_key,
                original_response=response,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )
            ## RESPONSE OBJECT
            response_obj = TextCompletionResponse(**response_json)
            response_obj._hidden_params.original_response = json.dumps(response_json)
            return response_obj
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    def streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        api_base: Optional[str] = None,
        max_retries=None,
        client=None,
        organization=None,
    ):
        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,  # type: ignore
                organization=organization,
            )
        else:
            openai_client = client

        try:
            raw_response = openai_client.completions.with_raw_response.create(**data)
            response = raw_response.parse()
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            for chunk in streamwrapper:
                yield chunk
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    async def async_streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        max_retries: int,
        api_base: Optional[str] = None,
        client=None,
        organization=None,
    ):
        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_client = client

        raw_response = await openai_client.completions.with_raw_response.create(**data)
        response = raw_response.parse()
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            async for transformed_chunk in streamwrapper:
                yield transformed_chunk
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )
@@ -0,0 +1,158 @@
"""
Support for gpt model family
"""

from typing import List, Optional, Union

from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse

from ..chat.gpt_transformation import OpenAIGPTConfig
from .utils import _transform_prompt


class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/completions/create

    The class `OpenAITextCompletionConfig` provides configuration for OpenAI's text completion API interface. Below are the parameters:

    - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.

    - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.

    - `frequency_penalty` (number or null): Defaults to 0. It is a number from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.

    - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.

    - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.

    - `temperature` (number or null): This optional parameter defines the sampling temperature to use.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    best_of: Optional[int] = None
    echo: Optional[bool] = None
    frequency_penalty: Optional[int] = None
    logit_bias: Optional[dict] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    suffix: Optional[str] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[int] = None,
        logit_bias: Optional[dict] = None,
        logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def convert_to_chat_model_response_object(
        self,
        response_object: Optional[TextCompletionResponse] = None,
        model_response_object: Optional[ModelResponse] = None,
    ):
        try:
            ## RESPONSE OBJECT
            if response_object is None or model_response_object is None:
                raise ValueError("Error in response object format")
            choice_list = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(
                    content=choice["text"],
                    role="assistant",
                )
                choice = Choices(
                    finish_reason=choice["finish_reason"],
                    index=idx,
                    message=message,
                    logprobs=choice.get("logprobs", None),
                )
                choice_list.append(choice)
            model_response_object.choices = choice_list

            if "usage" in response_object:
                setattr(model_response_object, "usage", response_object["usage"])

            if "id" in response_object:
                model_response_object.id = response_object["id"]

            if "model" in response_object:
                model_response_object.model = response_object["model"]

            model_response_object._hidden_params[
                "original_response"
            ] = response_object  # track original response, if users make a litellm.text_completion() request, we can return the original response
            return model_response_object
        except Exception as e:
            raise e

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "functions",
            "function_call",
            "temperature",
            "top_p",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "response_format",
            "seed",
            "tools",
            "tool_choice",
            "max_retries",
            "logprobs",
            "top_logprobs",
            "extra_headers",
        ]

    def transform_text_completion_request(
        self,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        prompt = _transform_prompt(messages)
        return {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }
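A small sketch of the request transformation above: the chat-style message list is flattened into a `prompt`, and the optional params are spread into the request body. The model name and values are illustrative.

# Illustrative only - uses the OpenAITextCompletionConfig defined above.
config = OpenAITextCompletionConfig()

body = config.transform_text_completion_request(
    model="gpt-3.5-turbo-instruct",
    messages=[{"role": "user", "content": "Say hi"}],
    optional_params={"max_tokens": 16, "temperature": 0},
    headers={},
)
# expected: {"model": "gpt-3.5-turbo-instruct", "prompt": "Say hi", "max_tokens": 16, "temperature": 0}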
@@ -0,0 +1,50 @@
from typing import List, Union, cast

from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    AllPromptValues,
    OpenAITextCompletionUserMessage,
)


def is_tokens_or_list_of_tokens(value: List):
    # Check if it's a list of integers (tokens)
    if isinstance(value, list) and all(isinstance(item, int) for item in value):
        return True
    # Check if it's a list of lists of integers (list of tokens)
    if isinstance(value, list) and all(
        isinstance(item, list) and all(isinstance(i, int) for i in item)
        for item in value
    ):
        return True
    return False


def _transform_prompt(
    messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
) -> AllPromptValues:
    if len(messages) == 1:  # base case
        message_content = messages[0].get("content")
        if (
            message_content
            and isinstance(message_content, list)
            and is_tokens_or_list_of_tokens(message_content)
        ):
            openai_prompt: AllPromptValues = cast(AllPromptValues, message_content)
        else:
            openai_prompt = ""
            content = convert_content_list_to_str(cast(AllMessageValues, messages[0]))
            openai_prompt += content
    else:
        prompt_str_list: List[str] = []
        for m in messages:
            try:  # expect list of int/list of list of int to be a 1 message array only.
                content = convert_content_list_to_str(cast(AllMessageValues, m))
                prompt_str_list.append(content)
            except Exception as e:
                raise e
        openai_prompt = prompt_str_list
    return openai_prompt
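A quick sketch of the two branches in `_transform_prompt` above: a single message whose content is already a token array is passed through untouched, while multiple messages are flattened into a list of prompt strings. Inputs are illustrative.

# Illustrative only - exercises _transform_prompt defined above.

# Single message whose content is already a list of token ids -> passed through as-is.
tokens = _transform_prompt([{"role": "user", "content": [1, 2, 3]}])
# expected: [1, 2, 3]

# Multiple messages -> each content string becomes one entry in a list prompt.
prompts = _transform_prompt(
    [
        {"role": "user", "content": "first prompt"},
        {"role": "user", "content": "second prompt"},
    ]
)
# expected: ["first prompt", "second prompt"]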
@@ -0,0 +1,122 @@
"""
Helper util for handling openai-specific cost calculation
- e.g.: prompt caching
"""

from typing import Literal, Optional, Tuple

from litellm._logging import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import CallTypes, Usage
from litellm.utils import get_model_info


def cost_router(call_type: CallTypes) -> Literal["cost_per_token", "cost_per_second"]:
    if call_type == CallTypes.atranscription or call_type == CallTypes.transcription:
        return "cost_per_second"
    else:
        return "cost_per_token"


def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    ## CALCULATE INPUT COST
    return generic_cost_per_token(
        model=model, usage=usage, custom_llm_provider="openai"
    )
    # ### Non-cached text tokens
    # non_cached_text_tokens = usage.prompt_tokens
    # cached_tokens: Optional[int] = None
    # if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
    #     cached_tokens = usage.prompt_tokens_details.cached_tokens
    #     non_cached_text_tokens = non_cached_text_tokens - cached_tokens
    # prompt_cost: float = non_cached_text_tokens * model_info["input_cost_per_token"]
    # ## Prompt Caching cost calculation
    # if model_info.get("cache_read_input_token_cost") is not None and cached_tokens:
    #     # Note: We read ._cache_read_input_tokens from the Usage - since cost_calculator.py standardizes the cache read tokens on usage._cache_read_input_tokens
    #     prompt_cost += cached_tokens * (
    #         model_info.get("cache_read_input_token_cost", 0) or 0
    #     )

    # _audio_tokens: Optional[int] = (
    #     usage.prompt_tokens_details.audio_tokens
    #     if usage.prompt_tokens_details is not None
    #     else None
    # )
    # _audio_cost_per_token: Optional[float] = model_info.get(
    #     "input_cost_per_audio_token"
    # )
    # if _audio_tokens is not None and _audio_cost_per_token is not None:
    #     audio_cost: float = _audio_tokens * _audio_cost_per_token
    #     prompt_cost += audio_cost

    # ## CALCULATE OUTPUT COST
    # completion_cost: float = (
    #     usage["completion_tokens"] * model_info["output_cost_per_token"]
    # )
    # _output_cost_per_audio_token: Optional[float] = model_info.get(
    #     "output_cost_per_audio_token"
    # )
    # _output_audio_tokens: Optional[int] = (
    #     usage.completion_tokens_details.audio_tokens
    #     if usage.completion_tokens_details is not None
    #     else None
    # )
    # if _output_cost_per_audio_token is not None and _output_audio_tokens is not None:
    #     audio_cost = _output_audio_tokens * _output_cost_per_audio_token
    #     completion_cost += audio_cost

    # return prompt_cost, completion_cost


def cost_per_second(
    model: str, custom_llm_provider: Optional[str], duration: float = 0.0
) -> Tuple[float, float]:
    """
    Calculates the cost per second for a given model and response duration.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, the custom llm provider
        - duration: float, the duration of the response in seconds

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    ## GET MODEL INFO
    model_info = get_model_info(
        model=model, custom_llm_provider=custom_llm_provider or "openai"
    )
    prompt_cost = 0.0
    completion_cost = 0.0
    ## Speech / Audio cost calculation
    if (
        "output_cost_per_second" in model_info
        and model_info["output_cost_per_second"] is not None
    ):
        verbose_logger.debug(
            f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; duration: {duration}"
        )
        ## COST PER SECOND ##
        completion_cost = model_info["output_cost_per_second"] * duration
    elif (
        "input_cost_per_second" in model_info
        and model_info["input_cost_per_second"] is not None
    ):
        verbose_logger.debug(
            f"For model={model} - input_cost_per_second: {model_info.get('input_cost_per_second')}; duration: {duration}"
        )
        ## COST PER SECOND ##
        prompt_cost = model_info["input_cost_per_second"] * duration
        completion_cost = 0.0

    return prompt_cost, completion_cost
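A worked example of the duration-based pricing above, under the assumption that the model's litellm entry defines `output_cost_per_second`; the rate below is made up for illustration.

# Illustrative only - the per-second rate is fabricated, not a real litellm price.
output_cost_per_second = 0.0002  # hypothetical USD per second of generated output
duration = 12.5                  # seconds of output

prompt_cost = 0.0
completion_cost = output_cost_per_second * duration  # 0.0002 * 12.5 = 0.0025 USD
# cost_per_second(...) would return (0.0, 0.0025) for such a model.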
Binary file not shown.
@@ -0,0 +1,275 @@
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
from openai.types.fine_tuning import FineTuningJob
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class OpenAIFineTuningAPI:
|
||||
"""
|
||||
OpenAI methods to support for batches
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_openai_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
_is_async: bool = False,
|
||||
api_version: Optional[str] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
|
||||
received_args = locals()
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client" or k == "_is_async":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["base_url"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
if _is_async is True:
|
||||
openai_client = AsyncOpenAI(**data)
|
||||
else:
|
||||
openai_client = OpenAI(**data) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
async def acreate_fine_tuning_job(
|
||||
self,
|
||||
create_fine_tuning_job_data: dict,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
) -> FineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.create(
|
||||
**create_fine_tuning_job_data
|
||||
)
|
||||
return response
|
||||
|
||||
def create_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_fine_tuning_job_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.acreate_fine_tuning_job( # type: ignore
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||
)
|
||||
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data)
|
||||
return response
|
||||
|
||||
async def acancel_fine_tuning_job(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
) -> FineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return response
|
||||
|
||||
def cancel_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
fine_tuning_job_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.acancel_fine_tuning_job( # type: ignore
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
|
||||
response = openai_client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return response
|
||||
|
||||
async def alist_fine_tuning_jobs(
|
||||
self,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
||||
|
||||
def list_fine_tuning_jobs(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.alist_fine_tuning_jobs( # type: ignore
|
||||
after=after,
|
||||
limit=limit,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
|
||||
response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
||||
|
||||
async def aretrieve_fine_tuning_job(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
) -> FineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.retrieve(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return response
|
||||
|
||||
def retrieve_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
fine_tuning_job_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
    if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
        raise ValueError(
            "OpenAI client is not an instance of AsyncOpenAI or AsyncAzureOpenAI. Make sure you passed the correct async client."
        )
|
||||
return self.aretrieve_fine_tuning_job( # type: ignore
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id)
|
||||
response = openai_client.fine_tuning.jobs.retrieve(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return response
|
||||
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
OpenAI Image Variations Handler
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import FileTypes, ImageResponse, LlmProviders
|
||||
from litellm.utils import ProviderConfigManager
|
||||
|
||||
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
|
||||
from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
|
||||
class OpenAIImageVariationsHandler:
|
||||
def get_sync_client(
|
||||
self,
|
||||
client: Optional[OpenAI],
|
||||
init_client_params: dict,
|
||||
):
|
||||
if client is None:
|
||||
openai_client = OpenAI(
|
||||
**init_client_params,
|
||||
)
|
||||
else:
|
||||
openai_client = client
|
||||
return openai_client
|
||||
|
||||
def get_async_client(
|
||||
self, client: Optional[AsyncOpenAI], init_client_params: dict
|
||||
) -> AsyncOpenAI:
|
||||
if client is None:
|
||||
openai_client = AsyncOpenAI(
|
||||
**init_client_params,
|
||||
)
|
||||
else:
|
||||
openai_client = client
|
||||
return openai_client
|
||||
|
||||
async def async_image_variations(
|
||||
self,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
organization: Optional[str],
|
||||
client: Optional[AsyncOpenAI],
|
||||
data: dict,
|
||||
headers: dict,
|
||||
model: Optional[str],
|
||||
timeout: float,
|
||||
max_retries: int,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
image: FileTypes,
|
||||
provider_config: BaseImageVariationConfig,
|
||||
) -> ImageResponse:
|
||||
try:
|
||||
init_client_params = {
|
||||
"api_key": api_key,
|
||||
"base_url": api_base,
|
||||
"http_client": litellm.client_session,
|
||||
"timeout": timeout,
|
||||
"max_retries": max_retries, # type: ignore
|
||||
"organization": organization,
|
||||
}
|
||||
|
||||
client = self.get_async_client(
|
||||
client=client, init_client_params=init_client_params
|
||||
)
|
||||
|
||||
raw_response = await client.images.with_raw_response.create_variation(**data) # type: ignore
|
||||
response = raw_response.parse()
|
||||
response_json = response.model_dump()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
api_key=api_key,
|
||||
original_response=response_json,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return provider_config.transform_response_image_variation(
|
||||
model=model,
|
||||
model_response=ImageResponse(**response_json),
|
||||
raw_response=httpx.Response(
|
||||
status_code=200,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
),
|
||||
logging_obj=logging_obj,
|
||||
request_data=data,
|
||||
image=image,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=None,
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_text = getattr(e, "text", str(e))
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise OpenAIError(
|
||||
status_code=status_code, message=error_text, headers=error_headers
|
||||
)
|
||||
|
||||
def image_variations(
|
||||
self,
|
||||
model_response: ImageResponse,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
model: Optional[str],
|
||||
image: FileTypes,
|
||||
timeout: float,
|
||||
custom_llm_provider: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
logger_fn=None,
|
||||
client=None,
|
||||
organization: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> ImageResponse:
|
||||
try:
|
||||
provider_config = ProviderConfigManager.get_provider_image_variation_config(
|
||||
model=model or "", # openai defaults to dall-e-2
|
||||
provider=LlmProviders.OPENAI,
|
||||
)
|
||||
|
||||
if provider_config is None:
|
||||
raise ValueError(
|
||||
f"image variation provider not found: {custom_llm_provider}."
|
||||
)
|
||||
|
||||
max_retries = optional_params.pop("max_retries", 2)
|
||||
|
||||
data = provider_config.transform_request_image_variation(
|
||||
model=model,
|
||||
image=image,
|
||||
optional_params=optional_params,
|
||||
headers=headers or {},
|
||||
)
|
||||
json_data = data.get("data")
|
||||
if not json_data:
|
||||
raise ValueError(
|
||||
f"data field is required, for openai image variations. Got={data}"
|
||||
)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input="",
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
if litellm_params.get("async_call", False):
|
||||
return self.async_image_variations(
|
||||
api_base=api_base,
|
||||
data=json_data,
|
||||
headers=headers or {},
|
||||
model_response=model_response,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
provider_config=provider_config,
|
||||
image=image,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
) # type: ignore
|
||||
|
||||
init_client_params = {
|
||||
"api_key": api_key,
|
||||
"base_url": api_base,
|
||||
"http_client": litellm.client_session,
|
||||
"timeout": timeout,
|
||||
"max_retries": max_retries, # type: ignore
|
||||
"organization": organization,
|
||||
}
|
||||
|
||||
client = self.get_sync_client(
|
||||
client=client, init_client_params=init_client_params
|
||||
)
|
||||
|
||||
raw_response = client.images.with_raw_response.create_variation(**json_data) # type: ignore
|
||||
response = raw_response.parse()
|
||||
response_json = response.model_dump()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
api_key=api_key,
|
||||
original_response=response_json,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return provider_config.transform_response_image_variation(
|
||||
model=model,
|
||||
model_response=ImageResponse(**response_json),
|
||||
raw_response=httpx.Response(
|
||||
status_code=200,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
),
|
||||
logging_obj=logging_obj,
|
||||
request_data=json_data,
|
||||
image=image,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=None,
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_text = getattr(e, "text", str(e))
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise OpenAIError(
|
||||
status_code=status_code, message=error_text, headers=error_headers
|
||||
)
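# --- Hedged sketch (not part of the file above): the error-normalization pattern ---
# Both except-blocks above flatten provider exceptions into (status_code, text, headers)
# before re-raising them as OpenAIError; this standalone helper only mirrors that logic.
def normalize_provider_error(e: Exception):
    status_code = getattr(e, "status_code", 500)
    error_headers = getattr(e, "headers", None)
    error_text = getattr(e, "text", str(e))
    error_response = getattr(e, "response", None)
    if error_headers is None and error_response is not None:
        error_headers = getattr(error_response, "headers", None)
    return status_code, error_text, error_headers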
|
||||
@@ -0,0 +1,82 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from aiohttp import ClientResponse
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.image_variations.transformation import LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
|
||||
from litellm.types.utils import FileTypes, HttpHandlerRequestFields, ImageResponse
|
||||
|
||||
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
|
||||
class OpenAIImageVariationConfig(BaseImageVariationConfig):
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageVariationOptionalParams]:
|
||||
return ["n", "size", "response_format", "user"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
optional_params.update(non_default_params)
|
||||
return optional_params
|
||||
|
||||
def transform_request_image_variation(
|
||||
self,
|
||||
model: Optional[str],
|
||||
image: FileTypes,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> HttpHandlerRequestFields:
|
||||
return {
|
||||
"data": {
|
||||
"image": image,
|
||||
**optional_params,
|
||||
}
|
||||
}
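# --- Illustrative example (not part of the class above): what the request transform returns ---
# Assumes OpenAIImageVariationConfig can be constructed with no arguments; the file name
# and optional params are placeholders.
config = OpenAIImageVariationConfig()
request_fields = config.transform_request_image_variation(
    model="dall-e-2",
    image=open("otter.png", "rb"),
    optional_params={"n": 2, "size": "512x512"},
    headers={},
)
# request_fields == {"data": {"image": <file object>, "n": 2, "size": "512x512"}}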
|
||||
|
||||
async def async_transform_response_image_variation(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: ClientResponse,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
image: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
) -> ImageResponse:
|
||||
return model_response
|
||||
|
||||
def transform_response_image_variation(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
image: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
) -> ImageResponse:
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, Headers]
|
||||
) -> BaseLLMException:
|
||||
return OpenAIError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=headers,
|
||||
)
|
||||
.venv/lib/python3.10/site-packages/litellm/llms/openai/openai.py (new file, 2862 lines): diff suppressed because it is too large
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
This file contains the calling Azure OpenAI's `/openai/realtime` endpoint.
|
||||
|
||||
This requires websockets, and is currently only supported on LiteLLM Proxy.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
|
||||
from ..openai import OpenAIChatCompletion
|
||||
|
||||
|
||||
class OpenAIRealtime(OpenAIChatCompletion):
|
||||
def _construct_url(self, api_base: str, model: str) -> str:
|
||||
"""
|
||||
Example output:
|
||||
"BACKEND_WS_URL = "wss://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"";
|
||||
"""
|
||||
api_base = api_base.replace("https://", "wss://")
|
||||
api_base = api_base.replace("http://", "ws://")
|
||||
return f"{api_base}/v1/realtime?model={model}"
|
||||
|
||||
async def async_realtime(
|
||||
self,
|
||||
model: str,
|
||||
websocket: Any,
|
||||
logging_obj: LiteLLMLogging,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
client: Optional[Any] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
import websockets
|
||||
|
||||
if api_base is None:
    raise ValueError("api_base is required for OpenAI realtime calls")
if api_key is None:
    raise ValueError("api_key is required for OpenAI realtime calls")
|
||||
|
||||
url = self._construct_url(api_base, model)
|
||||
|
||||
try:
|
||||
async with websockets.connect( # type: ignore
|
||||
url,
|
||||
extra_headers={
|
||||
"Authorization": f"Bearer {api_key}", # type: ignore
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
},
|
||||
) as backend_ws:
|
||||
realtime_streaming = RealTimeStreaming(
|
||||
websocket, backend_ws, logging_obj
|
||||
)
|
||||
await realtime_streaming.bidirectional_forward()
|
||||
|
||||
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
|
||||
await websocket.close(code=e.status_code, reason=str(e))
|
||||
except Exception as e:
|
||||
try:
|
||||
await websocket.close(
|
||||
code=1011, reason=f"Internal server error: {str(e)}"
|
||||
)
|
||||
except RuntimeError as close_error:
|
||||
if "already completed" in str(close_error) or "websocket.close" in str(
|
||||
close_error
|
||||
):
|
||||
# The WebSocket is already closed or the response is completed, so we can ignore this error
|
||||
pass
|
||||
else:
|
||||
# If it's a different RuntimeError, we might want to log it or handle it differently
|
||||
raise Exception(
|
||||
f"Unexpected error while closing WebSocket: {close_error}"
|
||||
)
|
||||
@@ -0,0 +1,252 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import *
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
All OpenAI Responses API params are supported
|
||||
"""
|
||||
return [
|
||||
"input",
|
||||
"model",
|
||||
"include",
|
||||
"instructions",
|
||||
"max_output_tokens",
|
||||
"metadata",
|
||||
"parallel_tool_calls",
|
||||
"previous_response_id",
|
||||
"reasoning",
|
||||
"store",
|
||||
"stream",
|
||||
"temperature",
|
||||
"text",
|
||||
"tool_choice",
|
||||
"tools",
|
||||
"top_p",
|
||||
"truncation",
|
||||
"user",
|
||||
"extra_headers",
|
||||
"extra_query",
|
||||
"extra_body",
|
||||
"timeout",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
response_api_optional_params: ResponsesAPIOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
"""No mapping applied since inputs are in OpenAI spec already"""
|
||||
return dict(response_api_optional_params)
|
||||
|
||||
def transform_responses_api_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, ResponseInputParam],
|
||||
response_api_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""No transform applied since inputs are in OpenAI spec already"""
|
||||
return dict(
|
||||
ResponsesAPIRequestParams(
|
||||
model=model, input=input, **response_api_optional_request_params
|
||||
)
|
||||
)
|
||||
|
||||
def transform_response_api_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
"""No transform applied since outputs are in OpenAI spec already"""
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise OpenAIError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
return ResponsesAPIResponse(**raw_response_json)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
)
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the endpoint for OpenAI responses API
|
||||
"""
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
|
||||
# Remove trailing slashes
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
return f"{api_base}/responses"
|
||||
|
||||
def transform_streaming_response(
|
||||
self,
|
||||
model: str,
|
||||
parsed_chunk: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIStreamingResponse:
|
||||
"""
|
||||
Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
|
||||
"""
|
||||
# Convert the dictionary to a properly typed ResponsesAPIStreamingResponse
|
||||
verbose_logger.debug("Raw OpenAI Chunk=%s", parsed_chunk)
|
||||
event_type = str(parsed_chunk.get("type"))
|
||||
event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class(
|
||||
event_type=event_type
|
||||
)
|
||||
return event_pydantic_model(**parsed_chunk)
|
||||
|
||||
@staticmethod
|
||||
def get_event_model_class(event_type: str) -> Any:
|
||||
"""
|
||||
Returns the appropriate event model class based on the event type.
|
||||
|
||||
Args:
|
||||
event_type (str): The type of event from the response chunk
|
||||
|
||||
Returns:
|
||||
Any: The corresponding event model class
|
||||
|
||||
Raises:
|
||||
ValueError: If the event type is unknown
|
||||
"""
|
||||
event_models = {
|
||||
ResponsesAPIStreamEvents.RESPONSE_CREATED: ResponseCreatedEvent,
|
||||
ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS: ResponseInProgressEvent,
|
||||
ResponsesAPIStreamEvents.RESPONSE_COMPLETED: ResponseCompletedEvent,
|
||||
ResponsesAPIStreamEvents.RESPONSE_FAILED: ResponseFailedEvent,
|
||||
ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE: ResponseIncompleteEvent,
|
||||
ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED: OutputItemAddedEvent,
|
||||
ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: OutputItemDoneEvent,
|
||||
ResponsesAPIStreamEvents.CONTENT_PART_ADDED: ContentPartAddedEvent,
|
||||
ResponsesAPIStreamEvents.CONTENT_PART_DONE: ContentPartDoneEvent,
|
||||
ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA: OutputTextDeltaEvent,
|
||||
ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED: OutputTextAnnotationAddedEvent,
|
||||
ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE: OutputTextDoneEvent,
|
||||
ResponsesAPIStreamEvents.REFUSAL_DELTA: RefusalDeltaEvent,
|
||||
ResponsesAPIStreamEvents.REFUSAL_DONE: RefusalDoneEvent,
|
||||
ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA: FunctionCallArgumentsDeltaEvent,
|
||||
ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE: FunctionCallArgumentsDoneEvent,
|
||||
ResponsesAPIStreamEvents.FILE_SEARCH_CALL_IN_PROGRESS: FileSearchCallInProgressEvent,
|
||||
ResponsesAPIStreamEvents.FILE_SEARCH_CALL_SEARCHING: FileSearchCallSearchingEvent,
|
||||
ResponsesAPIStreamEvents.FILE_SEARCH_CALL_COMPLETED: FileSearchCallCompletedEvent,
|
||||
ResponsesAPIStreamEvents.WEB_SEARCH_CALL_IN_PROGRESS: WebSearchCallInProgressEvent,
|
||||
ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING: WebSearchCallSearchingEvent,
|
||||
ResponsesAPIStreamEvents.WEB_SEARCH_CALL_COMPLETED: WebSearchCallCompletedEvent,
|
||||
ResponsesAPIStreamEvents.ERROR: ErrorEvent,
|
||||
}
|
||||
|
||||
model_class = event_models.get(cast(ResponsesAPIStreamEvents, event_type))
|
||||
if not model_class:
|
||||
return GenericEvent
|
||||
|
||||
return model_class
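# --- Illustrative example (not part of the class above) ---
# Event-type dispatch: known types map to their pydantic event class, anything else falls
# back to GenericEvent. The string below assumes RESPONSE_CREATED is serialized as
# "response.created" in ResponsesAPIStreamEvents; that value is an assumption.
OpenAIResponsesAPIConfig.get_event_model_class(event_type="response.created")
# -> ResponseCreatedEvent
OpenAIResponsesAPIConfig.get_event_model_class(event_type="some.future.event")
# -> GenericEvent (unknown event types are not an error)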
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
if stream is not True:
|
||||
return False
|
||||
if model is not None:
|
||||
try:
|
||||
if (
|
||||
litellm.utils.supports_native_streaming(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
is False
|
||||
):
|
||||
return True
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
#########################################################
|
||||
def transform_delete_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the delete response API request into a URL and data
|
||||
|
||||
OpenAI API expects the following request
|
||||
- DELETE /v1/responses/{response_id}
|
||||
"""
|
||||
url = f"{api_base}/{response_id}"
|
||||
data: Dict = {}
|
||||
return url, data
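# --- Illustrative example (not part of the class above) ---
# Assumes OpenAIResponsesAPIConfig() and GenericLiteLLMParams() take no required
# arguments; "resp_abc123" is a placeholder response id.
cfg = OpenAIResponsesAPIConfig()
url, data = cfg.transform_delete_response_api_request(
    response_id="resp_abc123",
    api_base="https://api.openai.com/v1/responses",
    litellm_params=GenericLiteLLMParams(),
    headers={},
)
# url == "https://api.openai.com/v1/responses/resp_abc123", data == {}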
|
||||
|
||||
def transform_delete_response_api_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> DeleteResponseResult:
|
||||
"""
|
||||
Transform the delete response API response into a DeleteResponseResult
|
||||
"""
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise OpenAIError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
return DeleteResponseResult(**raw_response_json)
|
||||
@@ -0,0 +1,34 @@
|
||||
from typing import List
|
||||
|
||||
from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
|
||||
from litellm.types.utils import FileTypes
|
||||
|
||||
from .whisper_transformation import OpenAIWhisperAudioTranscriptionConfig
|
||||
|
||||
|
||||
class OpenAIGPTAudioTranscriptionConfig(OpenAIWhisperAudioTranscriptionConfig):
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIAudioTranscriptionOptionalParams]:
|
||||
"""
|
||||
Get the supported OpenAI params for the `gpt-4o-transcribe` models
|
||||
"""
|
||||
return [
|
||||
"language",
|
||||
"prompt",
|
||||
"response_format",
|
||||
"temperature",
|
||||
"include",
|
||||
]
|
||||
|
||||
def transform_audio_transcription_request(
|
||||
self,
|
||||
model: str,
|
||||
audio_file: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the audio transcription request
|
||||
"""
|
||||
return {"model": model, "file": audio_file, **optional_params}
|
||||
@@ -0,0 +1,222 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.audio_transcription.transformation import (
|
||||
BaseAudioTranscriptionConfig,
|
||||
)
|
||||
from litellm.types.utils import FileTypes
|
||||
from litellm.utils import (
|
||||
TranscriptionResponse,
|
||||
convert_to_model_response_object,
|
||||
extract_duration_from_srt_or_vtt,
|
||||
)
|
||||
|
||||
from ..openai import OpenAIChatCompletion
|
||||
|
||||
|
||||
class OpenAIAudioTranscription(OpenAIChatCompletion):
|
||||
# Audio Transcriptions
|
||||
async def make_openai_audio_transcriptions_request(
|
||||
self,
|
||||
openai_aclient: AsyncOpenAI,
|
||||
data: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
):
|
||||
"""
|
||||
Helper to:
|
||||
- call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
|
||||
- call openai_aclient.audio.transcriptions.create by default
|
||||
"""
|
||||
try:
|
||||
raw_response = (
|
||||
await openai_aclient.audio.transcriptions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
) # type: ignore
|
||||
headers = dict(raw_response.headers)
|
||||
response = raw_response.parse()
|
||||
|
||||
return headers, response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def make_sync_openai_audio_transcriptions_request(
|
||||
self,
|
||||
openai_client: OpenAI,
|
||||
data: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
):
|
||||
"""
|
||||
Helper to:
|
||||
- call openai_aclient.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
|
||||
- call openai_aclient.audio.transcriptions.create by default
|
||||
"""
|
||||
try:
|
||||
if litellm.return_response_headers is True:
|
||||
raw_response = (
|
||||
openai_client.audio.transcriptions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
) # type: ignore
|
||||
headers = dict(raw_response.headers)
|
||||
response = raw_response.parse()
|
||||
return headers, response
|
||||
else:
|
||||
response = openai_client.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
|
||||
return None, response
|
||||
except Exception as e:
|
||||
raise e
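# --- Hedged usage sketch (not part of the class above) ---
# The sync helper returns (headers, response): headers is a dict of raw response headers
# when litellm.return_response_headers is True, otherwise None. `handler` is assumed to
# be an OpenAIAudioTranscription instance (defined above); key/file values are placeholders.
def example_sync_transcription(handler):
    litellm.return_response_headers = True
    headers, response = handler.make_sync_openai_audio_transcriptions_request(
        openai_client=OpenAI(api_key="sk-..."),
        data={"model": "whisper-1", "file": open("speech.mp3", "rb")},
        timeout=600.0,
    )
    return headers, response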
|
||||
|
||||
def audio_transcriptions(
|
||||
self,
|
||||
model: str,
|
||||
audio_file: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
max_retries: int,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
client=None,
|
||||
atranscription: bool = False,
|
||||
provider_config: Optional[BaseAudioTranscriptionConfig] = None,
|
||||
) -> TranscriptionResponse:
|
||||
"""
|
||||
Handle audio transcription request
|
||||
"""
|
||||
if provider_config is not None:
|
||||
data = provider_config.transform_audio_transcription_request(
|
||||
model=model,
|
||||
audio_file=audio_file,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
if isinstance(data, bytes):
|
||||
raise ValueError("OpenAI transformation route requires a dict")
|
||||
else:
|
||||
data = {"model": model, "file": audio_file, **optional_params}
|
||||
|
||||
if atranscription is True:
|
||||
return self.async_audio_transcriptions( # type: ignore
|
||||
audio_file=audio_file,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
openai_client: OpenAI = self._get_openai_client( # type: ignore
|
||||
is_async=False,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=None,
|
||||
api_key=openai_client.api_key,
|
||||
additional_args={
|
||||
"api_base": openai_client._base_url._uri_reference,
|
||||
"atranscription": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
_, response = self.make_sync_openai_audio_transcriptions_request(
|
||||
openai_client=openai_client,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if isinstance(response, BaseModel):
|
||||
stringified_response = response.model_dump()
|
||||
else:
|
||||
stringified_response = TranscriptionResponse(text=response).model_dump()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=get_audio_file_name(audio_file),
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
|
||||
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
|
||||
return final_response
|
||||
|
||||
async def async_audio_transcriptions(
|
||||
self,
|
||||
audio_file: FileTypes,
|
||||
data: dict,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
):
|
||||
try:
|
||||
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
|
||||
is_async=True,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=None,
|
||||
api_key=openai_aclient.api_key,
|
||||
additional_args={
|
||||
"api_base": openai_aclient._base_url._uri_reference,
|
||||
"atranscription": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
headers, response = await self.make_openai_audio_transcriptions_request(
|
||||
openai_aclient=openai_aclient,
|
||||
data=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
logging_obj.model_call_details["response_headers"] = headers
|
||||
if isinstance(response, BaseModel):
|
||||
stringified_response = response.model_dump()
|
||||
else:
|
||||
duration = extract_duration_from_srt_or_vtt(response)
|
||||
stringified_response = TranscriptionResponse(text=response).model_dump()
|
||||
stringified_response["duration"] = duration
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=get_audio_file_name(audio_file),
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
|
||||
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=get_audio_file_name(audio_file),
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
@@ -0,0 +1,98 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from httpx import Headers
|
||||
|
||||
from litellm.llms.base_llm.audio_transcription.transformation import (
|
||||
BaseAudioTranscriptionConfig,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIAudioTranscriptionOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import FileTypes
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
|
||||
class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIAudioTranscriptionOptionalParams]:
|
||||
"""
|
||||
Get the supported OpenAI params for the `whisper-1` models
|
||||
"""
|
||||
return [
|
||||
"language",
|
||||
"prompt",
|
||||
"response_format",
|
||||
"temperature",
|
||||
"timestamp_granularities",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map the OpenAI params to the Whisper params
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
for k, v in non_default_params.items():
|
||||
if k in supported_params:
|
||||
optional_params[k] = v
|
||||
return optional_params
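# --- Illustrative example (not part of the class above) ---
# Only params returned by get_supported_openai_params survive the mapping; unsupported
# keys are silently dropped. Assumes the config class takes no constructor arguments.
cfg = OpenAIWhisperAudioTranscriptionConfig()
out = cfg.map_openai_params(
    non_default_params={"language": "en", "logprobs": True},
    optional_params={},
    model="whisper-1",
    drop_params=True,
)
# out == {"language": "en"}  ("logprobs" is not a supported whisper param)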
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = api_key or get_secret_str("OPENAI_API_KEY")
|
||||
|
||||
auth_header = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
headers.update(auth_header)
|
||||
return headers
|
||||
|
||||
def transform_audio_transcription_request(
|
||||
self,
|
||||
model: str,
|
||||
audio_file: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the audio transcription request
|
||||
"""
|
||||
|
||||
data = {"model": model, "file": audio_file, **optional_params}
|
||||
|
||||
if "response_format" not in data or (
|
||||
data["response_format"] == "text" or data["response_format"] == "json"
|
||||
):
|
||||
data[
|
||||
"response_format"
|
||||
] = "verbose_json" # ensures 'duration' is received - used for cost calculation
|
||||
|
||||
return data
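# --- Illustrative example (not part of the class above) ---
# A "text" / "json" / missing response_format is coerced to "verbose_json" so the
# response includes 'duration' for cost calculation. The file name is a placeholder and
# the config class is assumed to take no constructor arguments.
cfg = OpenAIWhisperAudioTranscriptionConfig()
data = cfg.transform_audio_transcription_request(
    model="whisper-1",
    audio_file=open("speech.mp3", "rb"),
    optional_params={"response_format": "json"},
    litellm_params={},
)
# data["response_format"] == "verbose_json"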
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, Headers]
|
||||
) -> BaseLLMException:
|
||||
return OpenAIError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=headers,
|
||||
)
|
||||