structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,48 @@
"""
Support for the GPT-4o audio family
OpenAI Doc: https://platform.openai.com/docs/guides/audio/quickstart?audio-generation-quickstart-example=audio-in&lang=python
"""
import litellm
from .gpt_transformation import OpenAIGPTConfig
class OpenAIGPTAudioConfig(OpenAIGPTConfig):
"""
Reference: https://platform.openai.com/docs/guides/audio
"""
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
"""
Get the supported OpenAI params for the `gpt-audio` models
"""
all_openai_params = super().get_supported_openai_params(model=model)
audio_specific_params = ["audio"]
return all_openai_params + audio_specific_params
def is_model_gpt_audio_model(self, model: str) -> bool:
if model in litellm.open_ai_chat_completion_models and "audio" in model:
return True
return False
def _map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return super()._map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
)
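
A minimal usage sketch of the audio config above. The import path is an assumption (the diff does not show file paths); adjust it to wherever OpenAIGPTAudioConfig lives in this package.

# Hedged sketch: exercising OpenAIGPTAudioConfig without any network call.
from litellm.llms.openai.chat.gpt_audio_transformation import OpenAIGPTAudioConfig  # path assumed

config = OpenAIGPTAudioConfig()

# get_supported_openai_params appends "audio" to the base GPT param list
params = config.get_supported_openai_params(model="gpt-4o-audio-preview")
assert "audio" in params

# True only if the model is registered in litellm.open_ai_chat_completion_models
# and its name contains "audio"
print(config.is_model_gpt_audio_model("gpt-4o-audio-preview"))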

View File

@@ -0,0 +1,419 @@
"""
Support for gpt model family
"""
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
List,
Optional,
Union,
cast,
)
import httpx
import litellm
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
ChatCompletionFileObjectFile,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
)
from litellm.types.utils import ModelResponse, ModelResponseStream
from litellm.utils import convert_to_model_response_object
from ..common_utils import OpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
The class `OpenAIGPTConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
- `function_call` (string or object): This optional parameter controls how the model calls functions.
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, increasing the model's likelihood of talking about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
frequency_penalty: Optional[int] = None
function_call: Optional[Union[str, dict]] = None
functions: Optional[list] = None
logit_bias: Optional[dict] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
temperature: Optional[int] = None
top_p: Optional[int] = None
response_format: Optional[dict] = None
def __init__(
self,
frequency_penalty: Optional[int] = None,
function_call: Optional[Union[str, dict]] = None,
functions: Optional[list] = None,
logit_bias: Optional[dict] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
temperature: Optional[int] = None,
top_p: Optional[int] = None,
response_format: Optional[dict] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def get_supported_openai_params(self, model: str) -> list:
base_params = [
"frequency_penalty",
"logit_bias",
"logprobs",
"top_logprobs",
"max_tokens",
"max_completion_tokens",
"modalities",
"prediction",
"n",
"presence_penalty",
"seed",
"stop",
"stream",
"stream_options",
"temperature",
"top_p",
"tools",
"tool_choice",
"function_call",
"functions",
"max_retries",
"extra_headers",
"parallel_tool_calls",
"audio",
] # works across all models
model_specific_params = []
if (
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
): # gpt-3.5-turbo-16k and gpt-4 do not support 'response_format'
model_specific_params.append("response_format")
if (
model in litellm.open_ai_chat_completion_models
) or model in litellm.open_ai_text_completion_models:
model_specific_params.append(
"user"
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
return base_params + model_specific_params
def _map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
If any supported_openai_params are in non_default_params, add them to optional_params, so they are used in the API call
Args:
non_default_params (dict): Non-default parameters to filter.
optional_params (dict): Optional parameters to update.
model (str): Model name for parameter support check.
Returns:
dict: Updated optional_params with supported non-default parameters.
"""
supported_openai_params = self.get_supported_openai_params(model)
for param, value in non_default_params.items():
if param in supported_openai_params:
optional_params[param] = value
return optional_params
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return self._map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=drop_params,
)
def _transform_messages(
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:
"""OpenAI no longer supports image_url as a string, so we need to convert it to a dict"""
for message in messages:
message_content = message.get("content")
if message_content and isinstance(message_content, list):
for content_item in message_content:
litellm_specific_params = {"format"}
if content_item.get("type") == "image_url":
content_item = cast(ChatCompletionImageObject, content_item)
if isinstance(content_item["image_url"], str):
content_item["image_url"] = {
"url": content_item["image_url"],
}
elif isinstance(content_item["image_url"], dict):
new_image_url_obj = ChatCompletionImageUrlObject(
**{ # type: ignore
k: v
for k, v in content_item["image_url"].items()
if k not in litellm_specific_params
}
)
content_item["image_url"] = new_image_url_obj
elif content_item.get("type") == "file":
content_item = cast(ChatCompletionFileObject, content_item)
file_obj = content_item["file"]
new_file_obj = ChatCompletionFileObjectFile(
**{ # type: ignore
k: v
for k, v in file_obj.items()
if k not in litellm_specific_params
}
)
content_item["file"] = new_file_obj
return messages
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform the overall request to be sent to the API.
Returns:
dict: The transformed request. Sent as the body of the API call.
"""
messages = self._transform_messages(messages=messages, model=model)
return {
"model": model,
"messages": messages,
**optional_params,
}
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
"""
Transform the response from the API.
Returns:
dict: The transformed response.
"""
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=raw_response.text,
additional_args={"complete_input_dict": request_data},
)
## RESPONSE OBJECT
try:
completion_response = raw_response.json()
except Exception as e:
response_headers = getattr(raw_response, "headers", None)
raise OpenAIError(
message="Unable to get json response - {}, Original Response: {}".format(
str(e), raw_response.text
),
status_code=raw_response.status_code,
headers=response_headers,
)
raw_response_headers = dict(raw_response.headers)
final_response_obj = convert_to_model_response_object(
response_object=completion_response,
model_response_object=model_response,
hidden_params={"headers": raw_response_headers},
_response_headers=raw_response_headers,
)
return cast(ModelResponse, final_response_obj)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return OpenAIError(
status_code=status_code,
message=error_message,
headers=cast(httpx.Headers, headers),
)
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for the API call.
Returns:
str: The complete URL for the API call.
"""
if api_base is None:
api_base = "https://api.openai.com"
endpoint = "chat/completions"
# Remove trailing slash from api_base if present
api_base = api_base.rstrip("/")
# Check if endpoint is already in the api_base
if endpoint in api_base:
return api_base
return f"{api_base}/{endpoint}"
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
# Ensure Content-Type is set to application/json
if "content-type" not in headers and "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
return headers
def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
"""
Calls OpenAI's `/v1/models` endpoint and returns the list of models.
"""
if api_base is None:
api_base = "https://api.openai.com"
if api_key is None:
api_key = get_secret_str("OPENAI_API_KEY")
response = litellm.module_level_client.get(
url=f"{api_base}/v1/models",
headers={"Authorization": f"Bearer {api_key}"},
)
if response.status_code != 200:
raise Exception(f"Failed to get models: {response.text}")
models = response.json()["data"]
return [model["id"] for model in models]
@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
return (
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
return (
api_base
or litellm.api_base
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
@staticmethod
def get_base_model(model: Optional[str] = None) -> Optional[str]:
return model
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
return OpenAIChatCompletionStreamingHandler(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
try:
return ModelResponseStream(
id=chunk["id"],
object="chat.completion.chunk",
created=chunk["created"],
model=chunk["model"],
choices=chunk["choices"],
)
except Exception as e:
raise e
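
A short sketch of how the param mapping above behaves: map_openai_params copies only the params returned by get_supported_openai_params(model) from non_default_params into optional_params. The import path is an assumption.

# Hedged sketch of OpenAIGPTConfig param mapping (no network call).
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig  # path assumed

config = OpenAIGPTConfig()
optional_params = config.map_openai_params(
    non_default_params={"temperature": 0.2, "not_a_real_param": "dropped"},
    optional_params={},
    model="gpt-4o",
    drop_params=True,
)
print(optional_params)  # expected: {'temperature': 0.2} - unsupported keys are ignored

# get_complete_url falls back to https://api.openai.com and avoids duplicating
# the chat/completions suffix if it is already present in api_base
print(config.get_complete_url(api_base=None, api_key=None, model="gpt-4o",
                              optional_params={}, litellm_params={}))
# expected: https://api.openai.com/chat/completions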

View File

@@ -0,0 +1,3 @@
"""
LLM Calling done in `openai/openai.py`
"""

View File

@@ -0,0 +1,159 @@
"""
Support for o1/o3 model family
https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
"""
from typing import List, Optional
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
from litellm.utils import (
supports_function_calling,
supports_parallel_function_calling,
supports_response_schema,
supports_system_messages,
)
from .gpt_transformation import OpenAIGPTConfig
class OpenAIOSeriesConfig(OpenAIGPTConfig):
"""
Reference: https://platform.openai.com/docs/guides/reasoning
"""
@classmethod
def get_config(cls):
return super().get_config()
def translate_developer_role_to_system_role(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
O-series models support `developer` role.
"""
return messages
def get_supported_openai_params(self, model: str) -> list:
"""
Get the supported OpenAI params for the given model
"""
all_openai_params = super().get_supported_openai_params(model=model)
non_supported_params = [
"logprobs",
"top_p",
"presence_penalty",
"frequency_penalty",
"top_logprobs",
]
o_series_only_param = ["reasoning_effort"]
all_openai_params.extend(o_series_only_param)
try:
model, custom_llm_provider, api_base, api_key = get_llm_provider(
model=model
)
except Exception:
verbose_logger.debug(
f"Unable to infer model provider for model={model}, defaulting to openai for o1 supported param check"
)
custom_llm_provider = "openai"
_supports_function_calling = supports_function_calling(
model, custom_llm_provider
)
_supports_response_schema = supports_response_schema(model, custom_llm_provider)
_supports_parallel_tool_calls = supports_parallel_function_calling(
model, custom_llm_provider
)
if not _supports_function_calling:
non_supported_params.append("tools")
non_supported_params.append("tool_choice")
non_supported_params.append("function_call")
non_supported_params.append("functions")
if not _supports_parallel_tool_calls:
non_supported_params.append("parallel_tool_calls")
if not _supports_response_schema:
non_supported_params.append("response_format")
return [
param for param in all_openai_params if param not in non_supported_params
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
):
if "max_tokens" in non_default_params:
optional_params["max_completion_tokens"] = non_default_params.pop(
"max_tokens"
)
if "temperature" in non_default_params:
temperature_value: Optional[float] = non_default_params.pop("temperature")
if temperature_value is not None:
if temperature_value == 1:
optional_params["temperature"] = temperature_value
else:
## UNSUPPORTED TEMPERATURE VALUE
if litellm.drop_params is True or drop_params is True:
pass
else:
raise litellm.utils.UnsupportedParamsError(
message="O-series models don't support temperature={}. Only temperature=1 is supported. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
temperature_value
),
status_code=400,
)
return super()._map_openai_params(
non_default_params, optional_params, model, drop_params
)
def is_model_o_series_model(self, model: str) -> bool:
if model in litellm.open_ai_chat_completion_models and (
"o1" in model
or "o3" in model
or "o4"
in model # [TODO] make this a more generic check (e.g. using `openai-o-series` as provider like gemini)
):
return True
return False
def _transform_messages(
self, messages: List[AllMessageValues], model: str
) -> List[AllMessageValues]:
"""
Handles limitations of the O-series model family.
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
"""
_supports_system_messages = supports_system_messages(model, "openai")
for i, message in enumerate(messages):
if message["role"] == "system" and not _supports_system_messages:
new_message = ChatCompletionUserMessage(
content=message["content"], role="user"
)
messages[i] = new_message # Replace the old message with the new one
messages = super()._transform_messages(messages, model)
return messages
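
A hedged sketch of the o-series parameter rules above: max_tokens is renamed to max_completion_tokens, and a temperature other than 1 is either dropped (when drop_params is set) or raises UnsupportedParamsError. The import path and model name are assumptions.

# Hedged sketch of OpenAIOSeriesConfig param mapping (no network call).
from litellm.llms.openai.chat.o_series_transformation import OpenAIOSeriesConfig  # path assumed

config = OpenAIOSeriesConfig()
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2},
    optional_params={},
    model="o3-mini",  # assumed to be a registered o-series model
    drop_params=True,
)
print(optional_params)  # expected: {'max_completion_tokens': 256}; temperature=0.2 was dropped

# With drop_params=False the same call would instead raise
# litellm.utils.UnsupportedParamsError (only temperature=1 is accepted).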

View File

@@ -0,0 +1,208 @@
"""
Common helpers / utils across all OpenAI endpoints
"""
import hashlib
import json
from typing import Any, Dict, List, Literal, Optional, Union
import httpx
import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
class OpenAIError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[Union[dict, httpx.Headers]] = None,
body: Optional[dict] = None,
):
self.status_code = status_code
self.message = message
self.headers = headers
if request:
self.request = request
else:
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
if response:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
status_code=status_code,
message=self.message,
headers=self.headers,
request=self.request,
response=self.response,
body=body,
)
####### Error Handling Utils for OpenAI API #######################
###################################################################
def drop_params_from_unprocessable_entity_error(
e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
data: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper function to read an OpenAI UnprocessableEntityError and drop the params that raised an error, as identified from the error message.
Args:
e (UnprocessableEntityError): The UnprocessableEntityError exception
data (Dict[str, Any]): The original data dictionary containing all parameters
Returns:
Dict[str, Any]: A new dictionary with invalid parameters removed
"""
invalid_params: List[str] = []
if isinstance(e, httpx.HTTPStatusError):
error_json = e.response.json()
error_message = error_json.get("error", {})
error_body = error_message
else:
error_body = e.body
if (
error_body is not None
and isinstance(error_body, dict)
and error_body.get("message")
):
message = error_body.get("message", {})
if isinstance(message, str):
try:
message = json.loads(message)
except json.JSONDecodeError:
message = {"detail": message}
detail = message.get("detail")
if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
for error_dict in detail:
if (
error_dict.get("loc")
and isinstance(error_dict.get("loc"), list)
and len(error_dict.get("loc")) == 2
):
invalid_params.append(error_dict["loc"][1])
new_data = {k: v for k, v in data.items() if k not in invalid_params}
return new_data
class BaseOpenAILLM:
"""
Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
"""
@staticmethod
def get_cached_openai_client(
client_initialization_params: dict, client_type: Literal["openai", "azure"]
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
"""Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters"""
_cache_key = BaseOpenAILLM.get_openai_client_cache_key(
client_initialization_params=client_initialization_params,
client_type=client_type,
)
_cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
return _cached_client
@staticmethod
def set_cached_openai_client(
openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
client_type: Literal["openai", "azure"],
client_initialization_params: dict,
):
"""Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS SECONDS"""
_cache_key = BaseOpenAILLM.get_openai_client_cache_key(
client_initialization_params=client_initialization_params,
client_type=client_type,
)
litellm.in_memory_llm_clients_cache.set_cache(
key=_cache_key,
value=openai_client,
ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
)
@staticmethod
def get_openai_client_cache_key(
client_initialization_params: dict, client_type: Literal["openai", "azure"]
) -> str:
"""Creates a cache key for the OpenAI client based on the client initialization parameters"""
hashed_api_key = None
if client_initialization_params.get("api_key") is not None:
hash_object = hashlib.sha256(
client_initialization_params.get("api_key", "").encode()
)
# Hexadecimal representation of the hash
hashed_api_key = hash_object.hexdigest()
# Create a more readable cache key using a list of key-value pairs
key_parts = [
f"hashed_api_key={hashed_api_key}",
f"is_async={client_initialization_params.get('is_async')}",
]
LITELLM_CLIENT_SPECIFIC_PARAMS = [
"timeout",
"max_retries",
"organization",
"api_base",
]
openai_client_fields = (
BaseOpenAILLM.get_openai_client_initialization_param_fields(
client_type=client_type
)
+ LITELLM_CLIENT_SPECIFIC_PARAMS
)
for param in openai_client_fields:
key_parts.append(f"{param}={client_initialization_params.get(param)}")
_cache_key = ",".join(key_parts)
return _cache_key
@staticmethod
def get_openai_client_initialization_param_fields(
client_type: Literal["openai", "azure"]
) -> List[str]:
"""Returns a list of fields that are used to initialize the OpenAI client"""
import inspect
from openai import AzureOpenAI, OpenAI
if client_type == "openai":
signature = inspect.signature(OpenAI.__init__)
else:
signature = inspect.signature(AzureOpenAI.__init__)
# Extract parameter names, excluding 'self'
param_names = [param for param in signature.parameters if param != "self"]
return param_names
@staticmethod
def _get_async_http_client() -> Optional[httpx.AsyncClient]:
if litellm.aclient_session is not None:
return litellm.aclient_session
return httpx.AsyncClient(
limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
verify=litellm.ssl_verify,
)
@staticmethod
def _get_sync_http_client() -> Optional[httpx.Client]:
if litellm.client_session is not None:
return litellm.client_session
return httpx.Client(
limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
verify=litellm.ssl_verify,
)
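
A self-contained sketch of drop_params_from_unprocessable_entity_error using the httpx.HTTPStatusError branch above. The request/response objects are built locally, so no network call is made; the error payload shape mirrors the "detail" structure the function parses.

# Hedged sketch (import path assumed).
import httpx
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error

request = httpx.Request("POST", "https://example.invalid/v1/chat/completions")
response = httpx.Response(
    status_code=422,
    request=request,
    json={
        "error": {
            "message": '{"detail": [{"loc": ["body", "response_format"], "msg": "unsupported"}]}'
        }
    },
)
error = httpx.HTTPStatusError("422 Unprocessable Entity", request=request, response=response)

data = {"model": "my-model", "response_format": {"type": "json_object"}, "temperature": 0.1}
print(drop_params_from_unprocessable_entity_error(error, data))
# expected: {'model': 'my-model', 'temperature': 0.1} - 'response_format' was read
# from detail[0]["loc"][1] and removed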

View File

@@ -0,0 +1,318 @@
import json
from typing import Callable, List, Optional, Union
from openai import AsyncOpenAI, OpenAI
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
from litellm.utils import ProviderConfigManager
from ..common_utils import OpenAIError
from .transformation import OpenAITextCompletionConfig
class OpenAITextCompletion(BaseLLM):
openai_text_completion_global_config = OpenAITextCompletionConfig()
def __init__(self) -> None:
super().__init__()
def validate_environment(self, api_key):
headers = {
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def completion(
self,
model_response: ModelResponse,
api_key: str,
model: str,
messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
timeout: float,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
print_verbose: Optional[Callable] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
):
try:
if headers is None:
headers = self.validate_environment(api_key=api_key)
if model is None or messages is None:
raise OpenAIError(status_code=422, message="Missing model or messages")
# don't send max retries to the api, if set
provider_config = ProviderConfigManager.get_provider_text_completion_config(
model=model,
provider=LlmProviders(custom_llm_provider),
)
data = provider_config.transform_text_completion_request(
model=model,
messages=messages,
optional_params=optional_params,
headers=headers,
)
max_retries = data.pop("max_retries", 2)
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
if acompletion is True:
if optional_params.get("stream", False):
return self.async_streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries,
client=client,
organization=organization,
)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore
elif optional_params.get("stream", False):
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,
api_key=api_key,
data=data,
headers=headers,
model_response=model_response,
model=model,
timeout=timeout,
max_retries=max_retries, # type: ignore
client=client,
organization=organization,
)
else:
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
raw_response = openai_client.completions.with_raw_response.create(**data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return TextCompletionResponse(**response_json)
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
async def acompletion(
self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
api_key: str,
model: str,
timeout: float,
max_retries: int,
organization: Optional[str] = None,
client=None,
):
try:
if client is None:
openai_aclient = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
raw_response = await openai_aclient.completions.with_raw_response.create(
**data
)
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
response_obj = TextCompletionResponse(**response_json)
response_obj._hidden_params.original_response = json.dumps(response_json)
return response_obj
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
def streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
api_base: Optional[str] = None,
max_retries=None,
client=None,
organization=None,
):
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries, # type: ignore
organization=organization,
)
else:
openai_client = client
try:
raw_response = openai_client.completions.with_raw_response.create(**data)
response = raw_response.parse()
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
try:
for chunk in streamwrapper:
yield chunk
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
async def async_streaming(
self,
logging_obj,
api_key: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str,
timeout: float,
max_retries: int,
api_base: Optional[str] = None,
client=None,
organization=None,
):
if client is None:
openai_client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
raw_response = await openai_client.completions.with_raw_response.create(**data)
response = raw_response.parse()
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="text-completion-openai",
logging_obj=logging_obj,
stream_options=data.get("stream_options", None),
)
try:
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)

View File

@@ -0,0 +1,158 @@
"""
Support for gpt model family
"""
from typing import List, Optional, Union
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
from ..chat.gpt_transformation import OpenAIGPTConfig
from .utils import _transform_prompt
class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
"""
Reference: https://platform.openai.com/docs/api-reference/completions/create
The class `OpenAITextCompletionConfig` provides configuration for OpenAI's text completion API interface. Below are the parameters:
- `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.
- `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.
- `frequency_penalty` (number or null): Defaults to 0. It is a number from -2.0 to 2.0, where positive values decrease the model's likelihood of repeating the same line.
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
- `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.
- `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.
- `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.
- `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
- `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.
- `temperature` (number or null): This optional parameter defines the sampling temperature to use.
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
"""
best_of: Optional[int] = None
echo: Optional[bool] = None
frequency_penalty: Optional[int] = None
logit_bias: Optional[dict] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[int] = None
stop: Optional[Union[str, list]] = None
suffix: Optional[str] = None
def __init__(
self,
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[int] = None,
logit_bias: Optional[dict] = None,
logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[int] = None,
stop: Optional[Union[str, list]] = None,
suffix: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def convert_to_chat_model_response_object(
self,
response_object: Optional[TextCompletionResponse] = None,
model_response_object: Optional[ModelResponse] = None,
):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise ValueError("Error in response object format")
choice_list = []
for idx, choice in enumerate(response_object["choices"]):
message = Message(
content=choice["text"],
role="assistant",
)
choice = Choices(
finish_reason=choice["finish_reason"],
index=idx,
message=message,
logprobs=choice.get("logprobs", None),
)
choice_list.append(choice)
model_response_object.choices = choice_list
if "usage" in response_object:
setattr(model_response_object, "usage", response_object["usage"])
if "id" in response_object:
model_response_object.id = response_object["id"]
if "model" in response_object:
model_response_object.model = response_object["model"]
model_response_object._hidden_params[
"original_response"
] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
return model_response_object
except Exception as e:
raise e
def get_supported_openai_params(self, model: str) -> List:
return [
"functions",
"function_call",
"temperature",
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
"logprobs",
"top_logprobs",
"extra_headers",
]
def transform_text_completion_request(
self,
model: str,
messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
optional_params: dict,
headers: dict,
) -> dict:
prompt = _transform_prompt(messages)
return {
"model": model,
"prompt": prompt,
**optional_params,
}
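
A minimal sketch of transform_text_completion_request above: the chat-style messages are flattened into a `prompt` via _transform_prompt and merged with the optional params. The import path is an assumption.

# Hedged sketch (no network call).
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig  # path assumed

config = OpenAITextCompletionConfig()
request_body = config.transform_text_completion_request(
    model="gpt-3.5-turbo-instruct",
    messages=[{"role": "user", "content": "Say hi"}],
    optional_params={"max_tokens": 16},
    headers={},
)
print(request_body)
# expected: {'model': 'gpt-3.5-turbo-instruct', 'prompt': 'Say hi', 'max_tokens': 16}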

View File

@@ -0,0 +1,50 @@
from typing import List, Union, cast
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.types.llms.openai import (
AllMessageValues,
AllPromptValues,
OpenAITextCompletionUserMessage,
)
def is_tokens_or_list_of_tokens(value: List):
# Check if it's a list of integers (tokens)
if isinstance(value, list) and all(isinstance(item, int) for item in value):
return True
# Check if it's a list of lists of integers (list of tokens)
if isinstance(value, list) and all(
isinstance(item, list) and all(isinstance(i, int) for i in item)
for item in value
):
return True
return False
def _transform_prompt(
messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
) -> AllPromptValues:
if len(messages) == 1: # base case
message_content = messages[0].get("content")
if (
message_content
and isinstance(message_content, list)
and is_tokens_or_list_of_tokens(message_content)
):
openai_prompt: AllPromptValues = cast(AllPromptValues, message_content)
else:
openai_prompt = ""
content = convert_content_list_to_str(cast(AllMessageValues, messages[0]))
openai_prompt += content
else:
prompt_str_list: List[str] = []
for m in messages:
try: # expect list of int/list of list of int to be a 1 message array only.
content = convert_content_list_to_str(cast(AllMessageValues, m))
prompt_str_list.append(content)
except Exception as e:
raise e
openai_prompt = prompt_str_list
return openai_prompt
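
A few illustrative calls to _transform_prompt above (import path assumed): a single string-content message becomes one prompt string, token lists are passed through untouched, and multiple messages become a list of prompt strings.

# Hedged sketch (no network call).
from litellm.llms.openai.completion.utils import _transform_prompt  # path assumed

print(_transform_prompt([{"role": "user", "content": "Hello"}]))    # expected: "Hello"
print(_transform_prompt([{"role": "user", "content": [1, 2, 3]}]))  # expected: [1, 2, 3] (token ids)
print(_transform_prompt([
    {"role": "user", "content": "first"},
    {"role": "user", "content": "second"},
]))  # expected: ["first", "second"]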

View File

@@ -0,0 +1,122 @@
"""
Helper util for handling openai-specific cost calculation
- e.g.: prompt caching
"""
from typing import Literal, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import CallTypes, Usage
from litellm.utils import get_model_info
def cost_router(call_type: CallTypes) -> Literal["cost_per_token", "cost_per_second"]:
if call_type == CallTypes.atranscription or call_type == CallTypes.transcription:
return "cost_per_second"
else:
return "cost_per_token"
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
"""
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
Input:
- model: str, the model name without provider prefix
- usage: LiteLLM Usage block, containing prompt caching information
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
## CALCULATE INPUT COST
return generic_cost_per_token(
model=model, usage=usage, custom_llm_provider="openai"
)
# ### Non-cached text tokens
# non_cached_text_tokens = usage.prompt_tokens
# cached_tokens: Optional[int] = None
# if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
# cached_tokens = usage.prompt_tokens_details.cached_tokens
# non_cached_text_tokens = non_cached_text_tokens - cached_tokens
# prompt_cost: float = non_cached_text_tokens * model_info["input_cost_per_token"]
# ## Prompt Caching cost calculation
# if model_info.get("cache_read_input_token_cost") is not None and cached_tokens:
# # Note: We read ._cache_read_input_tokens from the Usage - since cost_calculator.py standardizes the cache read tokens on usage._cache_read_input_tokens
# prompt_cost += cached_tokens * (
# model_info.get("cache_read_input_token_cost", 0) or 0
# )
# _audio_tokens: Optional[int] = (
# usage.prompt_tokens_details.audio_tokens
# if usage.prompt_tokens_details is not None
# else None
# )
# _audio_cost_per_token: Optional[float] = model_info.get(
# "input_cost_per_audio_token"
# )
# if _audio_tokens is not None and _audio_cost_per_token is not None:
# audio_cost: float = _audio_tokens * _audio_cost_per_token
# prompt_cost += audio_cost
# ## CALCULATE OUTPUT COST
# completion_cost: float = (
# usage["completion_tokens"] * model_info["output_cost_per_token"]
# )
# _output_cost_per_audio_token: Optional[float] = model_info.get(
# "output_cost_per_audio_token"
# )
# _output_audio_tokens: Optional[int] = (
# usage.completion_tokens_details.audio_tokens
# if usage.completion_tokens_details is not None
# else None
# )
# if _output_cost_per_audio_token is not None and _output_audio_tokens is not None:
# audio_cost = _output_audio_tokens * _output_cost_per_audio_token
# completion_cost += audio_cost
# return prompt_cost, completion_cost
def cost_per_second(
model: str, custom_llm_provider: Optional[str], duration: float = 0.0
) -> Tuple[float, float]:
"""
Calculates the cost per second for a given model and response duration.
Input:
- model: str, the model name without provider prefix
- custom_llm_provider: str, the custom llm provider
- duration: float, the duration of the response in seconds
Returns:
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
"""
## GET MODEL INFO
model_info = get_model_info(
model=model, custom_llm_provider=custom_llm_provider or "openai"
)
prompt_cost = 0.0
completion_cost = 0.0
## Speech / Audio cost calculation
if (
"output_cost_per_second" in model_info
and model_info["output_cost_per_second"] is not None
):
verbose_logger.debug(
f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; duration: {duration}"
)
## COST PER SECOND ##
completion_cost = model_info["output_cost_per_second"] * duration
elif (
"input_cost_per_second" in model_info
and model_info["input_cost_per_second"] is not None
):
verbose_logger.debug(
f"For model={model} - input_cost_per_second: {model_info.get('input_cost_per_second')}; duration: {duration}"
)
## COST PER SECOND ##
prompt_cost = model_info["input_cost_per_second"] * duration
completion_cost = 0.0
return prompt_cost, completion_cost
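
A hedged sketch of the two entry points above: cost_router sends transcription calls down the per-second path, and cost_per_second returns a (prompt_cost, completion_cost) tuple whose non-zero side depends on whether the model defines input_cost_per_second or output_cost_per_second in litellm's model cost map. The import path and the example model are assumptions.

# Hedged sketch (no network call; reads litellm's bundled model cost map).
from litellm.llms.openai.cost_calculation import cost_per_second, cost_router  # path assumed
from litellm.types.utils import CallTypes

print(cost_router(CallTypes.transcription))  # expected: "cost_per_second"
print(cost_router(CallTypes.completion))     # expected: "cost_per_token"

prompt_cost, completion_cost = cost_per_second(
    model="whisper-1",  # assumed to be priced per second in the cost map
    custom_llm_provider="openai",
    duration=12.5,
)
print(prompt_cost, completion_cost)  # actual values depend on the cost map entries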

View File

@@ -0,0 +1,275 @@
from typing import Any, Coroutine, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types.fine_tuning import FineTuningJob
from litellm._logging import verbose_logger
class OpenAIFineTuningAPI:
"""
OpenAI methods to support fine-tuning jobs
"""
def __init__(self) -> None:
super().__init__()
def get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
received_args = locals()
openai_client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
if _is_async is True:
openai_client = AsyncOpenAI(**data)
else:
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_fine_tuning_job(
self,
create_fine_tuning_job_data: dict,
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
) -> FineTuningJob:
response = await openai_client.fine_tuning.jobs.create(
**create_fine_tuning_job_data
)
return response
def create_fine_tuning_job(
self,
_is_async: bool,
create_fine_tuning_job_data: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
openai_client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
api_version=api_version,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_fine_tuning_job( # type: ignore
create_fine_tuning_job_data=create_fine_tuning_job_data,
openai_client=openai_client,
)
verbose_logger.debug(
"creating fine tuning job, args= %s", create_fine_tuning_job_data
)
response = openai_client.fine_tuning.jobs.create(**create_fine_tuning_job_data)
return response
async def acancel_fine_tuning_job(
self,
fine_tuning_job_id: str,
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
) -> FineTuningJob:
response = await openai_client.fine_tuning.jobs.cancel(
fine_tuning_job_id=fine_tuning_job_id
)
return response
def cancel_fine_tuning_job(
self,
_is_async: bool,
fine_tuning_job_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
):
openai_client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
api_version=api_version,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acancel_fine_tuning_job( # type: ignore
fine_tuning_job_id=fine_tuning_job_id,
openai_client=openai_client,
)
verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
response = openai_client.fine_tuning.jobs.cancel(
fine_tuning_job_id=fine_tuning_job_id
)
return response
async def alist_fine_tuning_jobs(
self,
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
after: Optional[str] = None,
limit: Optional[int] = None,
):
response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
return response
def list_fine_tuning_jobs(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
after: Optional[str] = None,
limit: Optional[int] = None,
):
openai_client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
api_version=api_version,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_fine_tuning_jobs( # type: ignore
after=after,
limit=limit,
openai_client=openai_client,
)
verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
return response
async def aretrieve_fine_tuning_job(
self,
fine_tuning_job_id: str,
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
) -> FineTuningJob:
response = await openai_client.fine_tuning.jobs.retrieve(
fine_tuning_job_id=fine_tuning_job_id
)
return response
def retrieve_fine_tuning_job(
self,
_is_async: bool,
fine_tuning_job_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
):
openai_client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
_is_async=_is_async,
api_version=api_version,
)
if openai_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_fine_tuning_job( # type: ignore
fine_tuning_job_id=fine_tuning_job_id,
openai_client=openai_client,
)
verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id)
response = openai_client.fine_tuning.jobs.retrieve(
fine_tuning_job_id=fine_tuning_job_id
)
return response
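
A hedged usage sketch of the handler above in sync mode. This issues a real call to OpenAI's fine-tuning API when executed; the API key, model name, and training file id are placeholders.

# Hedged sketch (import path assumed; performs a real API call when executed).
import os
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI  # path assumed

api = OpenAIFineTuningAPI()
job = api.create_fine_tuning_job(
    _is_async=False,
    create_fine_tuning_job_data={
        "model": "gpt-4o-mini-2024-07-18",   # placeholder model
        "training_file": "file-abc123",      # placeholder file id
    },
    api_key=os.environ.get("OPENAI_API_KEY"),
    api_base=None,
    api_version=None,
    timeout=600.0,
    max_retries=2,
    organization=None,
)
print(job.id, job.status)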

View File

@@ -0,0 +1,244 @@
"""
OpenAI Image Variations Handler
"""
from typing import Callable, Optional
import httpx
from openai import AsyncOpenAI, OpenAI
import litellm
from litellm.types.utils import FileTypes, ImageResponse, LlmProviders
from litellm.utils import ProviderConfigManager
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj
from ..common_utils import OpenAIError
class OpenAIImageVariationsHandler:
def get_sync_client(
self,
client: Optional[OpenAI],
init_client_params: dict,
):
if client is None:
openai_client = OpenAI(
**init_client_params,
)
else:
openai_client = client
return openai_client
def get_async_client(
self, client: Optional[AsyncOpenAI], init_client_params: dict
) -> AsyncOpenAI:
if client is None:
openai_client = AsyncOpenAI(
**init_client_params,
)
else:
openai_client = client
return openai_client
async def async_image_variations(
self,
api_key: str,
api_base: str,
organization: Optional[str],
client: Optional[AsyncOpenAI],
data: dict,
headers: dict,
model: Optional[str],
timeout: float,
max_retries: int,
logging_obj: LiteLLMLoggingObj,
model_response: ImageResponse,
optional_params: dict,
litellm_params: dict,
image: FileTypes,
provider_config: BaseImageVariationConfig,
) -> ImageResponse:
try:
init_client_params = {
"api_key": api_key,
"base_url": api_base,
"http_client": litellm.client_session,
"timeout": timeout,
"max_retries": max_retries, # type: ignore
"organization": organization,
}
client = self.get_async_client(
client=client, init_client_params=init_client_params
)
raw_response = await client.images.with_raw_response.create_variation(**data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return provider_config.transform_response_image_variation(
model=model,
model_response=ImageResponse(**response_json),
raw_response=httpx.Response(
status_code=200,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
),
logging_obj=logging_obj,
request_data=data,
image=image,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=None,
api_key=api_key,
)
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)
def image_variations(
self,
model_response: ImageResponse,
api_key: str,
api_base: str,
model: Optional[str],
image: FileTypes,
timeout: float,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
litellm_params: dict,
print_verbose: Optional[Callable] = None,
logger_fn=None,
client=None,
organization: Optional[str] = None,
headers: Optional[dict] = None,
) -> ImageResponse:
try:
provider_config = ProviderConfigManager.get_provider_image_variation_config(
model=model or "", # openai defaults to dall-e-2
provider=LlmProviders.OPENAI,
)
if provider_config is None:
raise ValueError(
f"image variation provider not found: {custom_llm_provider}."
)
max_retries = optional_params.pop("max_retries", 2)
data = provider_config.transform_request_image_variation(
model=model,
image=image,
optional_params=optional_params,
headers=headers or {},
)
json_data = data.get("data")
if not json_data:
raise ValueError(
f"data field is required for openai image variations. Got={data}"
)
## LOGGING
logging_obj.pre_call(
input="",
api_key=api_key,
additional_args={
"headers": headers,
"api_base": api_base,
"complete_input_dict": data,
},
)
if litellm_params.get("async_call", False):
return self.async_image_variations(
api_base=api_base,
data=json_data,
headers=headers or {},
model_response=model_response,
api_key=api_key,
logging_obj=logging_obj,
model=model,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
provider_config=provider_config,
image=image,
optional_params=optional_params,
litellm_params=litellm_params,
) # type: ignore
init_client_params = {
"api_key": api_key,
"base_url": api_base,
"http_client": litellm.client_session,
"timeout": timeout,
"max_retries": max_retries, # type: ignore
"organization": organization,
}
client = self.get_sync_client(
client=client, init_client_params=init_client_params
)
raw_response = client.images.with_raw_response.create_variation(**json_data) # type: ignore
response = raw_response.parse()
response_json = response.model_dump()
## LOGGING
logging_obj.post_call(
api_key=api_key,
original_response=response_json,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return provider_config.transform_response_image_variation(
model=model,
model_response=ImageResponse(**response_json),
raw_response=httpx.Response(
status_code=200,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
),
logging_obj=logging_obj,
request_data=json_data,
image=image,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=None,
api_key=api_key,
)
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
error_text = getattr(e, "text", str(e))
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
raise OpenAIError(
status_code=status_code, message=error_text, headers=error_headers
)

View File

@@ -0,0 +1,82 @@
from typing import Any, List, Optional, Union
from aiohttp import ClientResponse
from httpx import Headers, Response
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.image_variations.transformation import LiteLLMLoggingObj
from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
from litellm.types.utils import FileTypes, HttpHandlerRequestFields, ImageResponse
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
from ..common_utils import OpenAIError
class OpenAIImageVariationConfig(BaseImageVariationConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIImageVariationOptionalParams]:
return ["n", "size", "response_format", "user"]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
optional_params.update(non_default_params)
return optional_params
def transform_request_image_variation(
self,
model: Optional[str],
image: FileTypes,
optional_params: dict,
headers: dict,
) -> HttpHandlerRequestFields:
return {
"data": {
"image": image,
**optional_params,
}
}
async def async_transform_response_image_variation(
self,
model: Optional[str],
raw_response: ClientResponse,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
image: FileTypes,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
) -> ImageResponse:
return model_response
def transform_response_image_variation(
self,
model: Optional[str],
raw_response: Response,
model_response: ImageResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
image: FileTypes,
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
) -> ImageResponse:
return model_response
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return OpenAIError(
status_code=status_code,
message=error_message,
headers=headers,
)
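
# Usage sketch (hypothetical image bytes): map_openai_params is a passthrough merge,
# and transform_request_image_variation nests the image plus mapped params under
# "data" for the HTTP handler.
if __name__ == "__main__":
    config = OpenAIImageVariationConfig()
    params = config.map_openai_params(
        non_default_params={"n": 2, "size": "512x512"},
        optional_params={},
        model="dall-e-2",
        drop_params=False,
    )
    fields = config.transform_request_image_variation(
        model="dall-e-2",
        image=b"<raw image bytes>",  # placeholder for real image content
        optional_params=params,
        headers={},
    )
    # fields == {"data": {"image": b"<raw image bytes>", "n": 2, "size": "512x512"}}
    print(fields["data"].keys())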

File diff suppressed because it is too large

View File

@@ -0,0 +1,73 @@
"""
This file handles calling OpenAI's `/v1/realtime` endpoint.
This requires websockets, and is currently only supported on LiteLLM Proxy.
"""
from typing import Any, Optional
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
from ..openai import OpenAIChatCompletion
class OpenAIRealtime(OpenAIChatCompletion):
def _construct_url(self, api_base: str, model: str) -> str:
"""
Example output:
"BACKEND_WS_URL = "wss://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"";
"""
api_base = api_base.replace("https://", "wss://")
api_base = api_base.replace("http://", "ws://")
return f"{api_base}/v1/realtime?model={model}"
async def async_realtime(
self,
model: str,
websocket: Any,
logging_obj: LiteLLMLogging,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
client: Optional[Any] = None,
timeout: Optional[float] = None,
):
import websockets
if api_base is None:
raise ValueError("api_base is required for Azure OpenAI calls")
if api_key is None:
raise ValueError("api_key is required for Azure OpenAI calls")
url = self._construct_url(api_base, model)
try:
async with websockets.connect( # type: ignore
url,
extra_headers={
"Authorization": f"Bearer {api_key}", # type: ignore
"OpenAI-Beta": "realtime=v1",
},
) as backend_ws:
realtime_streaming = RealTimeStreaming(
websocket, backend_ws, logging_obj
)
await realtime_streaming.bidirectional_forward()
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
await websocket.close(code=e.status_code, reason=str(e))
except Exception as e:
try:
await websocket.close(
code=1011, reason=f"Internal server error: {str(e)}"
)
except RuntimeError as close_error:
if "already completed" in str(close_error) or "websocket.close" in str(
close_error
):
# The WebSocket is already closed or the response is completed, so we can ignore this error
pass
else:
# If it's a different RuntimeError, we might want to log it or handle it differently
raise Exception(
f"Unexpected error while closing WebSocket: {close_error}"
)
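
# Usage sketch: _construct_url only swaps the scheme to ws(s) and appends the
# realtime path; api_base and model below are illustrative values.
if __name__ == "__main__":
    handler = OpenAIRealtime()
    print(
        handler._construct_url(
            api_base="https://api.openai.com",
            model="gpt-4o-realtime-preview-2024-10-01",
        )
    )
    # wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01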

View File

@@ -0,0 +1,252 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import *
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from ..common_utils import OpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
def get_supported_openai_params(self, model: str) -> list:
"""
All OpenAI Responses API params are supported
"""
return [
"input",
"model",
"include",
"instructions",
"max_output_tokens",
"metadata",
"parallel_tool_calls",
"previous_response_id",
"reasoning",
"store",
"stream",
"temperature",
"text",
"tool_choice",
"tools",
"top_p",
"truncation",
"user",
"extra_headers",
"extra_query",
"extra_body",
"timeout",
]
def map_openai_params(
self,
response_api_optional_params: ResponsesAPIOptionalRequestParams,
model: str,
drop_params: bool,
) -> Dict:
"""No mapping applied since inputs are in OpenAI spec already"""
return dict(response_api_optional_params)
def transform_responses_api_request(
self,
model: str,
input: Union[str, ResponseInputParam],
response_api_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Dict:
"""No transform applied since inputs are in OpenAI spec already"""
return dict(
ResponsesAPIRequestParams(
model=model, input=input, **response_api_optional_request_params
)
)
def transform_response_api_response(
self,
model: str,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
"""No transform applied since outputs are in OpenAI spec already"""
try:
raw_response_json = raw_response.json()
except Exception:
raise OpenAIError(
message=raw_response.text, status_code=raw_response.status_code
)
return ResponsesAPIResponse(**raw_response_json)
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
api_key = (
api_key
or litellm.api_key
or litellm.openai_key
or get_secret_str("OPENAI_API_KEY")
)
headers.update(
{
"Authorization": f"Bearer {api_key}",
}
)
return headers
def get_complete_url(
self,
api_base: Optional[str],
litellm_params: dict,
) -> str:
"""
Get the endpoint for OpenAI responses API
"""
api_base = (
api_base
or litellm.api_base
or get_secret_str("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
# Remove trailing slashes
api_base = api_base.rstrip("/")
return f"{api_base}/responses"
def transform_streaming_response(
self,
model: str,
parsed_chunk: dict,
logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIStreamingResponse:
"""
Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
"""
# Convert the dictionary to a properly typed ResponsesAPIStreamingResponse
verbose_logger.debug("Raw OpenAI Chunk=%s", parsed_chunk)
event_type = str(parsed_chunk.get("type"))
event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class(
event_type=event_type
)
return event_pydantic_model(**parsed_chunk)
@staticmethod
def get_event_model_class(event_type: str) -> Any:
"""
Returns the appropriate event model class based on the event type.
Args:
event_type (str): The type of event from the response chunk
Returns:
Any: The corresponding event model class
Raises:
ValueError: If the event type is unknown
"""
event_models = {
ResponsesAPIStreamEvents.RESPONSE_CREATED: ResponseCreatedEvent,
ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS: ResponseInProgressEvent,
ResponsesAPIStreamEvents.RESPONSE_COMPLETED: ResponseCompletedEvent,
ResponsesAPIStreamEvents.RESPONSE_FAILED: ResponseFailedEvent,
ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE: ResponseIncompleteEvent,
ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED: OutputItemAddedEvent,
ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: OutputItemDoneEvent,
ResponsesAPIStreamEvents.CONTENT_PART_ADDED: ContentPartAddedEvent,
ResponsesAPIStreamEvents.CONTENT_PART_DONE: ContentPartDoneEvent,
ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA: OutputTextDeltaEvent,
ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED: OutputTextAnnotationAddedEvent,
ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE: OutputTextDoneEvent,
ResponsesAPIStreamEvents.REFUSAL_DELTA: RefusalDeltaEvent,
ResponsesAPIStreamEvents.REFUSAL_DONE: RefusalDoneEvent,
ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA: FunctionCallArgumentsDeltaEvent,
ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE: FunctionCallArgumentsDoneEvent,
ResponsesAPIStreamEvents.FILE_SEARCH_CALL_IN_PROGRESS: FileSearchCallInProgressEvent,
ResponsesAPIStreamEvents.FILE_SEARCH_CALL_SEARCHING: FileSearchCallSearchingEvent,
ResponsesAPIStreamEvents.FILE_SEARCH_CALL_COMPLETED: FileSearchCallCompletedEvent,
ResponsesAPIStreamEvents.WEB_SEARCH_CALL_IN_PROGRESS: WebSearchCallInProgressEvent,
ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING: WebSearchCallSearchingEvent,
ResponsesAPIStreamEvents.WEB_SEARCH_CALL_COMPLETED: WebSearchCallCompletedEvent,
ResponsesAPIStreamEvents.ERROR: ErrorEvent,
}
model_class = event_models.get(cast(ResponsesAPIStreamEvents, event_type))
if not model_class:
return GenericEvent
return model_class
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
if stream is not True:
return False
if model is not None:
try:
if (
litellm.utils.supports_native_streaming(
model=model,
custom_llm_provider=custom_llm_provider,
)
is False
):
return True
except Exception as e:
verbose_logger.debug(
f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
)
return False
#########################################################
########## DELETE RESPONSE API TRANSFORMATION ##############
#########################################################
def transform_delete_response_api_request(
self,
response_id: str,
api_base: str,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> Tuple[str, Dict]:
"""
Transform the delete response API request into a URL and data
OpenAI API expects the following request
- DELETE /v1/responses/{response_id}
"""
url = f"{api_base}/{response_id}"
data: Dict = {}
return url, data
def transform_delete_response_api_response(
self,
raw_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
) -> DeleteResponseResult:
"""
Transform the delete response API response into a DeleteResponseResult
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise OpenAIError(
message=raw_response.text, status_code=raw_response.status_code
)
return DeleteResponseResult(**raw_response_json)
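
# Usage sketch (illustrative values): endpoint resolution, the delete-request URL,
# and the GenericEvent fallback for stream event types that have no dedicated model.
if __name__ == "__main__":
    config = OpenAIResponsesAPIConfig()

    # trailing slash is stripped before "/responses" is appended
    print(config.get_complete_url(api_base="https://api.openai.com/v1/", litellm_params={}))
    # -> https://api.openai.com/v1/responses

    url, data = config.transform_delete_response_api_request(
        response_id="resp_abc123",  # hypothetical response id
        api_base="https://api.openai.com/v1/responses",
        litellm_params=GenericLiteLLMParams(),
        headers={},
    )
    # url == "https://api.openai.com/v1/responses/resp_abc123", data == {}

    # unknown event types fall back to GenericEvent instead of raising
    print(OpenAIResponsesAPIConfig.get_event_model_class(event_type="response.some_new_event"))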

View File

@@ -0,0 +1,34 @@
from typing import List
from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
from litellm.types.utils import FileTypes
from .whisper_transformation import OpenAIWhisperAudioTranscriptionConfig
class OpenAIGPTAudioTranscriptionConfig(OpenAIWhisperAudioTranscriptionConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIAudioTranscriptionOptionalParams]:
"""
Get the supported OpenAI params for the `gpt-4o-transcribe` models
"""
return [
"language",
"prompt",
"response_format",
"temperature",
"include",
]
def transform_audio_transcription_request(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
) -> dict:
"""
Transform the audio transcription request
"""
return {"model": model, "file": audio_file, **optional_params}

View File

@@ -0,0 +1,222 @@
from typing import Optional, Union
import httpx
from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel
import litellm
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
from litellm.types.utils import FileTypes
from litellm.utils import (
TranscriptionResponse,
convert_to_model_response_object,
extract_duration_from_srt_or_vtt,
)
from ..openai import OpenAIChatCompletion
class OpenAIAudioTranscription(OpenAIChatCompletion):
# Audio Transcriptions
async def make_openai_audio_transcriptions_request(
self,
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
        Helper to call openai_aclient.audio.transcriptions.with_raw_response.create
        and return the response headers alongside the parsed response.
"""
try:
raw_response = (
await openai_aclient.audio.transcriptions.with_raw_response.create(
**data, timeout=timeout
)
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
def make_sync_openai_audio_transcriptions_request(
self,
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
        - call openai_client.audio.transcriptions.with_raw_response when litellm.return_response_headers is True
        - call openai_client.audio.transcriptions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (
openai_client.audio.transcriptions.with_raw_response.create(
**data, timeout=timeout
)
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = openai_client.audio.transcriptions.create(**data, timeout=timeout) # type: ignore
return None, response
except Exception as e:
raise e
def audio_transcriptions(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
api_base: Optional[str],
client=None,
atranscription: bool = False,
provider_config: Optional[BaseAudioTranscriptionConfig] = None,
) -> TranscriptionResponse:
"""
Handle audio transcription request
"""
if provider_config is not None:
data = provider_config.transform_audio_transcription_request(
model=model,
audio_file=audio_file,
optional_params=optional_params,
litellm_params=litellm_params,
)
if isinstance(data, bytes):
raise ValueError("OpenAI transformation route requires a dict")
else:
data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True:
return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file,
data=data,
model_response=model_response,
timeout=timeout,
api_key=api_key,
api_base=api_base,
client=client,
max_retries=max_retries,
logging_obj=logging_obj,
)
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=None,
api_key=openai_client.api_key,
additional_args={
"api_base": openai_client._base_url._uri_reference,
"atranscription": True,
"complete_input_dict": data,
},
)
_, response = self.make_sync_openai_audio_transcriptions_request(
openai_client=openai_client,
data=data,
timeout=timeout,
)
if isinstance(response, BaseModel):
stringified_response = response.model_dump()
else:
stringified_response = TranscriptionResponse(text=response).model_dump()
## LOGGING
logging_obj.post_call(
input=get_audio_file_name(audio_file),
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
self,
audio_file: FileTypes,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
):
try:
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
)
## LOGGING
logging_obj.pre_call(
input=None,
api_key=openai_aclient.api_key,
additional_args={
"api_base": openai_aclient._base_url._uri_reference,
"atranscription": True,
"complete_input_dict": data,
},
)
headers, response = await self.make_openai_audio_transcriptions_request(
openai_aclient=openai_aclient,
data=data,
timeout=timeout,
)
logging_obj.model_call_details["response_headers"] = headers
if isinstance(response, BaseModel):
stringified_response = response.model_dump()
else:
duration = extract_duration_from_srt_or_vtt(response)
stringified_response = TranscriptionResponse(text=response).model_dump()
stringified_response["duration"] = duration
## LOGGING
logging_obj.post_call(
input=get_audio_file_name(audio_file),
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
                input=get_audio_file_name(audio_file),
api_key=api_key,
original_response=str(e),
)
raise e
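
# Sketch (hypothetical VTT payload): how the non-BaseModel branch above wraps a
# plain-text transcription - the duration is recovered from the srt/vtt timestamps
# and attached before conversion to the model response object.
if __name__ == "__main__":
    vtt_body = "WEBVTT\n\n00:00:00.000 --> 00:00:02.500\nhello world\n"
    duration = extract_duration_from_srt_or_vtt(vtt_body)
    stringified_response = TranscriptionResponse(text=vtt_body).model_dump()
    stringified_response["duration"] = duration
    print(stringified_response["duration"])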

View File

@@ -0,0 +1,98 @@
from typing import List, Optional, Union
from httpx import Headers
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIAudioTranscriptionOptionalParams,
)
from litellm.types.utils import FileTypes
from ..common_utils import OpenAIError
class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
def get_supported_openai_params(
self, model: str
) -> List[OpenAIAudioTranscriptionOptionalParams]:
"""
Get the supported OpenAI params for the `whisper-1` models
"""
return [
"language",
"prompt",
"response_format",
"temperature",
"timestamp_granularities",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map the OpenAI params to the Whisper params
"""
supported_params = self.get_supported_openai_params(model)
for k, v in non_default_params.items():
if k in supported_params:
optional_params[k] = v
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
api_key = api_key or get_secret_str("OPENAI_API_KEY")
auth_header = {
"Authorization": f"Bearer {api_key}",
}
headers.update(auth_header)
return headers
def transform_audio_transcription_request(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
litellm_params: dict,
) -> dict:
"""
Transform the audio transcription request
"""
data = {"model": model, "file": audio_file, **optional_params}
if "response_format" not in data or (
data["response_format"] == "text" or data["response_format"] == "json"
):
data[
"response_format"
] = "verbose_json" # ensures 'duration' is received - used for cost calculation
return data
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
return OpenAIError(
status_code=status_code,
message=error_message,
headers=headers,
)
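
# Usage sketch (illustrative params): unsupported params are dropped by
# map_openai_params, and "text"/"json" response formats are upgraded to
# "verbose_json" so a duration is available for cost calculation.
if __name__ == "__main__":
    config = OpenAIWhisperAudioTranscriptionConfig()
    params = config.map_openai_params(
        non_default_params={"temperature": 0.2, "foo": "bar"},  # "foo" is unsupported
        optional_params={},
        model="whisper-1",
        drop_params=False,
    )
    # params == {"temperature": 0.2}
    request = config.transform_audio_transcription_request(
        model="whisper-1",
        audio_file=b"<raw audio bytes>",  # placeholder
        optional_params={**params, "response_format": "json"},
        litellm_params={},
    )
    print(request["response_format"])  # verbose_json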