structure saas with tools
12  .venv/lib/python3.10/site-packages/litellm/llms/README.md  Normal file
@@ -0,0 +1,12 @@
## File Structure

### August 27th, 2024

To make it easy to see how calls are transformed for each model/provider:

We are working on moving all supported litellm providers to a folder structure, where the folder name is the supported litellm provider name.

Each folder will contain a `*_transformation.py` file, which has all the request/response transformation logic, making it easy to see how calls are modified.

E.g. `cohere/`, `bedrock/`.
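For illustration, the intended layout looks roughly like this (the provider folders and file names below are examples only, not an authoritative listing of the litellm tree):

```
llms/
├── cohere/
│   └── cohere_transformation.py   # request/response transformation logic for Cohere
├── bedrock/
│   └── bedrock_transformation.py
└── ...
```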
@@ -0,0 +1 @@
from . import *
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,70 @@
"""
AI21 Chat Completions API

this is OpenAI compatible - no translation needed / occurs
"""

from typing import Optional, Union

from ...openai_like.chat.transformation import OpenAILikeChatConfig


class AI21ChatConfig(OpenAILikeChatConfig):
    """
    Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters

    Below are the parameters:
    """

    tools: Optional[list] = None
    response_format: Optional[dict] = None
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    seed: Optional[int] = None
    tool_choice: Optional[str] = None
    user: Optional[str] = None

    def __init__(
        self,
        tools: Optional[list] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        seed: Optional[int] = None,
        tool_choice: Optional[str] = None,
        user: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the given model
        """
        return [
            "tools",
            "response_format",
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "n",
            "stream",
            "seed",
            "tool_choice",
        ]
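A hedged usage sketch (not part of the vendored file above): it shows how a config's `get_supported_openai_params()` list can be used to drop unsupported OpenAI-style kwargs before a request is built. The helper name `filter_supported_params` is invented for illustration and is not a litellm API.

```python
from typing import Any, Dict


def filter_supported_params(config, model: str, request_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the kwargs the provider config reports as supported (illustrative helper)."""
    supported = set(config.get_supported_openai_params(model=model))
    return {k: v for k, v in request_kwargs.items() if k in supported}


# Example (assuming AI21ChatConfig from the file above is importable):
# filter_supported_params(AI21ChatConfig(), "jamba-1.5-large",
#                         {"max_tokens": 256, "logit_bias": {"50256": -100}})
# -> {"max_tokens": 256}   # logit_bias is dropped; it is not in the supported list
```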
Binary file not shown.
@@ -0,0 +1,82 @@
"""
*New config* for using aiohttp to make the request to the custom OpenAI-like provider

This leads to 10x higher RPS than httpx
https://github.com/BerriAI/litellm/issues/6592

New config to ensure we introduce this without causing breaking changes for users
"""

from typing import TYPE_CHECKING, Any, List, Optional

from aiohttp import ClientResponse

from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, ModelResponse

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Ensure - /v1/chat/completions is at the end of the url
        """
        if api_base is None:
            api_base = "https://api.openai.com"

        if not api_base.endswith("/chat/completions"):
            api_base += "/chat/completions"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return {"Authorization": f"Bearer {api_key}"}

    async def transform_response(  # type: ignore
        self,
        model: str,
        raw_response: ClientResponse,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        _json_response = await raw_response.json()
        model_response.id = _json_response.get("id")
        model_response.choices = [
            Choices(**choice) for choice in _json_response.get("choices")
        ]
        model_response.created = _json_response.get("created")
        model_response.model = _json_response.get("model")
        model_response.object = _json_response.get("object")
        model_response.system_fingerprint = _json_response.get("system_fingerprint")
        return model_response
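A minimal standalone sketch of the URL-completion rule that `get_complete_url` above implements; the function below is a simplified restatement for illustration, not the litellm implementation itself.

```python
from typing import Optional


def complete_chat_url(api_base: Optional[str]) -> str:
    # Default the base URL, then make sure the /chat/completions suffix is present exactly once.
    if api_base is None:
        api_base = "https://api.openai.com"
    if not api_base.endswith("/chat/completions"):
        api_base += "/chat/completions"
    return api_base


assert complete_chat_url(None) == "https://api.openai.com/chat/completions"
assert complete_chat_url("https://example.test/v1") == "https://example.test/v1/chat/completions"
assert complete_chat_url("https://example.test/v1/chat/completions").endswith("/v1/chat/completions")
```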
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
from .handler import AnthropicChatCompletion, ModelResponseIterator
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,846 @@
|
||||
"""
|
||||
Calling + translation logic for anthropic's `/v1/messages` endpoint
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
import litellm
|
||||
import litellm.litellm_core_utils
|
||||
import litellm.types
|
||||
import litellm.types.utils
|
||||
from litellm import LlmProviders
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.anthropic import (
|
||||
ContentBlockDelta,
|
||||
ContentBlockStart,
|
||||
ContentBlockStop,
|
||||
MessageBlockDelta,
|
||||
MessageStartBlock,
|
||||
UsageDelta,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionRedactedThinkingBlock,
|
||||
ChatCompletionThinkingBlock,
|
||||
ChatCompletionToolCallChunk,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
GenericStreamingChunk,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
Usage,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
|
||||
|
||||
from ...base import BaseLLM
|
||||
from ..common_utils import AnthropicError, process_anthropic_headers
|
||||
from .transformation import AnthropicConfig
|
||||
|
||||
|
||||
async def make_call(
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
json_mode: bool,
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
if client is None:
|
||||
client = litellm.module_level_aclient
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise AnthropicError(
|
||||
status_code=e.response.status_code,
|
||||
message=await e.response.aread(),
|
||||
headers=error_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise AnthropicError(status_code=500, message=str(e))
|
||||
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.aiter_lines(),
|
||||
sync_stream=False,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=completion_stream, # Pass the completion stream for logging
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream, response.headers
|
||||
|
||||
|
||||
def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
json_mode: bool,
|
||||
) -> Tuple[Any, httpx.Headers]:
|
||||
if client is None:
|
||||
client = litellm.module_level_client # re-use a module level client
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
api_base, headers=headers, data=data, stream=True, timeout=timeout
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise AnthropicError(
|
||||
status_code=e.response.status_code,
|
||||
message=e.response.read(),
|
||||
headers=error_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
for exception in litellm.LITELLM_EXCEPTION_TYPES:
|
||||
if isinstance(e, exception):
|
||||
raise e
|
||||
raise AnthropicError(status_code=500, message=str(e))
|
||||
|
||||
if response.status_code != 200:
|
||||
response_headers = getattr(response, "headers", None)
|
||||
raise AnthropicError(
|
||||
status_code=response.status_code,
|
||||
message=response.read(),
|
||||
headers=response_headers,
|
||||
)
|
||||
|
||||
completion_stream = ModelResponseIterator(
|
||||
streaming_response=response.iter_lines(), sync_stream=True, json_mode=json_mode
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream, response.headers
|
||||
|
||||
|
||||
class AnthropicChatCompletion(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def acompletion_stream_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
client: Optional[AsyncHTTPHandler],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data: dict,
|
||||
json_mode: bool,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
):
|
||||
data["stream"] = True
|
||||
|
||||
completion_stream, headers = await make_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic",
|
||||
logging_obj=logging_obj,
|
||||
_response_headers=process_anthropic_headers(headers),
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
async def acompletion_function(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
stream,
|
||||
_is_function_call,
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
json_mode: bool,
|
||||
litellm_params: dict,
|
||||
provider_config: BaseConfig,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
async_handler = client or get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.ANTHROPIC
|
||||
)
|
||||
|
||||
try:
|
||||
response = await async_handler.post(
|
||||
api_base, headers=headers, json=data, timeout=timeout
|
||||
)
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_text = getattr(e, "text", str(e))
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
if error_response and hasattr(error_response, "text"):
|
||||
error_text = getattr(error_response, "text", error_text)
|
||||
raise AnthropicError(
|
||||
message=error_text,
|
||||
status_code=status_code,
|
||||
headers=error_headers,
|
||||
)
|
||||
|
||||
return provider_config.transform_response(
|
||||
model=model,
|
||||
raw_response=response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
request_data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
litellm_params: dict,
|
||||
acompletion=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client=None,
|
||||
):
|
||||
optional_params = copy.deepcopy(optional_params)
|
||||
stream = optional_params.pop("stream", None)
|
||||
json_mode: bool = optional_params.pop("json_mode", False)
|
||||
is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
|
||||
_is_function_call = False
|
||||
messages = copy.deepcopy(messages)
|
||||
headers = AnthropicConfig().validate_environment(
|
||||
api_key=api_key,
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params={**optional_params, "is_vertex_request": is_vertex_request},
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
config = ProviderConfigManager.get_provider_chat_config(
|
||||
model=model,
|
||||
provider=LlmProviders(custom_llm_provider),
|
||||
)
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
f"Provider config not found for model: {model} and provider: {custom_llm_provider}"
|
||||
)
|
||||
|
||||
data = config.transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
print_verbose(f"_is_function_call: {_is_function_call}")
|
||||
if acompletion is True:
|
||||
if (
|
||||
stream is True
|
||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||
print_verbose("makes async anthropic streaming POST request")
|
||||
data["stream"] = stream
|
||||
return self.acompletion_stream_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
_is_function_call=_is_function_call,
|
||||
json_mode=json_mode,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
return self.acompletion_function(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
provider_config=config,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
_is_function_call=_is_function_call,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
client=client,
|
||||
json_mode=json_mode,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
## COMPLETION CALL
|
||||
if (
|
||||
stream is True
|
||||
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
|
||||
data["stream"] = stream
|
||||
completion_stream, headers = make_sync_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers, # type: ignore
|
||||
data=json.dumps(data),
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="anthropic",
|
||||
logging_obj=logging_obj,
|
||||
_response_headers=process_anthropic_headers(headers),
|
||||
)
|
||||
|
||||
else:
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(timeout=timeout) # type: ignore
|
||||
else:
|
||||
client = client
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_text = getattr(e, "text", str(e))
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
if error_response and hasattr(error_response, "text"):
|
||||
error_text = getattr(error_response, "text", error_text)
|
||||
raise AnthropicError(
|
||||
message=error_text,
|
||||
status_code=status_code,
|
||||
headers=error_headers,
|
||||
)
|
||||
|
||||
return config.transform_response(
|
||||
model=model,
|
||||
raw_response=response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
request_data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def embedding(self):
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.content_blocks: List[ContentBlockDelta] = []
|
||||
self.tool_index = -1
|
||||
self.json_mode = json_mode
|
||||
|
||||
def check_empty_tool_call_args(self) -> bool:
|
||||
"""
|
||||
Check if the tool call block so far has been an empty string
|
||||
"""
|
||||
args = ""
|
||||
# if text content block -> skip
|
||||
if len(self.content_blocks) == 0:
|
||||
return False
|
||||
|
||||
if (
|
||||
self.content_blocks[0]["delta"]["type"] == "text_delta"
|
||||
or self.content_blocks[0]["delta"]["type"] == "thinking_delta"
|
||||
):
|
||||
return False
|
||||
|
||||
for block in self.content_blocks:
|
||||
if block["delta"]["type"] == "input_json_delta":
|
||||
args += block["delta"].get("partial_json", "") # type: ignore
|
||||
|
||||
if len(args) == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle_usage(self, anthropic_usage_chunk: Union[dict, UsageDelta]) -> Usage:
|
||||
return AnthropicConfig().calculate_usage(
|
||||
usage_object=cast(dict, anthropic_usage_chunk), reasoning_content=None
|
||||
)
|
||||
|
||||
def _content_block_delta_helper(
|
||||
self, chunk: dict
|
||||
) -> Tuple[
|
||||
str,
|
||||
Optional[ChatCompletionToolCallChunk],
|
||||
List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]],
|
||||
Dict[str, Any],
|
||||
]:
|
||||
"""
|
||||
Helper function to handle the content block delta
|
||||
"""
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
provider_specific_fields = {}
|
||||
content_block = ContentBlockDelta(**chunk) # type: ignore
|
||||
thinking_blocks: List[
|
||||
Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
|
||||
] = []
|
||||
|
||||
self.content_blocks.append(content_block)
|
||||
if "text" in content_block["delta"]:
|
||||
text = content_block["delta"]["text"]
|
||||
elif "partial_json" in content_block["delta"]:
|
||||
tool_use = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": None,
|
||||
"arguments": content_block["delta"]["partial_json"],
|
||||
},
|
||||
"index": self.tool_index,
|
||||
}
|
||||
elif "citation" in content_block["delta"]:
|
||||
provider_specific_fields["citation"] = content_block["delta"]["citation"]
|
||||
elif (
|
||||
"thinking" in content_block["delta"]
|
||||
or "signature" in content_block["delta"]
|
||||
):
|
||||
thinking_blocks = [
|
||||
ChatCompletionThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=content_block["delta"].get("thinking") or "",
|
||||
signature=content_block["delta"].get("signature"),
|
||||
)
|
||||
]
|
||||
provider_specific_fields["thinking_blocks"] = thinking_blocks
|
||||
|
||||
return text, tool_use, thinking_blocks, provider_specific_fields
|
||||
|
||||
def _handle_reasoning_content(
|
||||
self,
|
||||
thinking_blocks: List[
|
||||
Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
|
||||
],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Handle the reasoning content
|
||||
"""
|
||||
reasoning_content = None
|
||||
for block in thinking_blocks:
|
||||
thinking_content = cast(Optional[str], block.get("thinking"))
|
||||
if reasoning_content is None:
|
||||
reasoning_content = ""
|
||||
if thinking_content is not None:
|
||||
reasoning_content += thinking_content
|
||||
return reasoning_content
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||
try:
|
||||
type_chunk = chunk.get("type", "") or ""
|
||||
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
finish_reason = ""
|
||||
usage: Optional[Usage] = None
|
||||
provider_specific_fields: Dict[str, Any] = {}
|
||||
reasoning_content: Optional[str] = None
|
||||
thinking_blocks: Optional[
|
||||
List[
|
||||
Union[
|
||||
ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock
|
||||
]
|
||||
]
|
||||
] = None
|
||||
|
||||
index = int(chunk.get("index", 0))
|
||||
if type_chunk == "content_block_delta":
|
||||
"""
|
||||
Anthropic content chunk
|
||||
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
|
||||
"""
|
||||
(
|
||||
text,
|
||||
tool_use,
|
||||
thinking_blocks,
|
||||
provider_specific_fields,
|
||||
) = self._content_block_delta_helper(chunk=chunk)
|
||||
if thinking_blocks:
|
||||
reasoning_content = self._handle_reasoning_content(
|
||||
thinking_blocks=thinking_blocks
|
||||
)
|
||||
elif type_chunk == "content_block_start":
|
||||
"""
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}
|
||||
"""
|
||||
content_block_start = ContentBlockStart(**chunk) # type: ignore
|
||||
self.content_blocks = [] # reset content blocks when new block starts
|
||||
if content_block_start["content_block"]["type"] == "text":
|
||||
text = content_block_start["content_block"]["text"]
|
||||
elif content_block_start["content_block"]["type"] == "tool_use":
|
||||
self.tool_index += 1
|
||||
tool_use = {
|
||||
"id": content_block_start["content_block"]["id"],
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": content_block_start["content_block"]["name"],
|
||||
"arguments": "",
|
||||
},
|
||||
"index": self.tool_index,
|
||||
}
|
||||
elif (
|
||||
content_block_start["content_block"]["type"] == "redacted_thinking"
|
||||
):
|
||||
thinking_blocks = [
|
||||
ChatCompletionRedactedThinkingBlock(
|
||||
type="redacted_thinking",
|
||||
data=content_block_start["content_block"]["data"],
|
||||
)
|
||||
]
|
||||
elif type_chunk == "content_block_stop":
|
||||
ContentBlockStop(**chunk) # type: ignore
|
||||
# check if tool call content block
|
||||
is_empty = self.check_empty_tool_call_args()
|
||||
|
||||
if is_empty:
|
||||
tool_use = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": None,
|
||||
"arguments": "{}",
|
||||
},
|
||||
"index": self.tool_index,
|
||||
}
|
||||
elif type_chunk == "message_delta":
|
||||
"""
|
||||
Anthropic
|
||||
chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}}
|
||||
"""
|
||||
# TODO - get usage from this chunk, set in response
|
||||
message_delta = MessageBlockDelta(**chunk) # type: ignore
|
||||
finish_reason = map_finish_reason(
|
||||
finish_reason=message_delta["delta"].get("stop_reason", "stop")
|
||||
or "stop"
|
||||
)
|
||||
usage = self._handle_usage(anthropic_usage_chunk=message_delta["usage"])
|
||||
elif type_chunk == "message_start":
|
||||
"""
|
||||
Anthropic
|
||||
chunk = {
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_vrtx_011PqREFEMzd3REdCoUFAmdG",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"content": [],
|
||||
"stop_reason": null,
|
||||
"stop_sequence": null,
|
||||
"usage": {
|
||||
"input_tokens": 270,
|
||||
"output_tokens": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
message_start_block = MessageStartBlock(**chunk) # type: ignore
|
||||
if "usage" in message_start_block["message"]:
|
||||
usage = self._handle_usage(
|
||||
anthropic_usage_chunk=message_start_block["message"]["usage"]
|
||||
)
|
||||
elif type_chunk == "error":
|
||||
"""
|
||||
{"type":"error","error":{"details":null,"type":"api_error","message":"Internal server error"} }
|
||||
"""
|
||||
_error_dict = chunk.get("error", {}) or {}
|
||||
message = _error_dict.get("message", None) or str(chunk)
|
||||
raise AnthropicError(
|
||||
message=message,
|
||||
status_code=500, # it looks like Anthropic API does not return a status code in the chunk error - default to 500
|
||||
)
|
||||
|
||||
text, tool_use = self._handle_json_mode_chunk(text=text, tool_use=tool_use)
|
||||
|
||||
returned_chunk = ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=index,
|
||||
delta=Delta(
|
||||
content=text,
|
||||
tool_calls=[tool_use] if tool_use is not None else None,
|
||||
provider_specific_fields=(
|
||||
provider_specific_fields
|
||||
if provider_specific_fields
|
||||
else None
|
||||
),
|
||||
thinking_blocks=(
|
||||
thinking_blocks if thinking_blocks else None
|
||||
),
|
||||
reasoning_content=reasoning_content,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return returned_chunk
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
||||
def _handle_json_mode_chunk(
|
||||
self, text: str, tool_use: Optional[ChatCompletionToolCallChunk]
|
||||
) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]:
|
||||
"""
|
||||
If JSON mode is enabled, convert the tool call to a message.
|
||||
|
||||
Anthropic returns the JSON schema as part of the tool call
|
||||
OpenAI returns the JSON schema as part of the content, this handles placing it in the content
|
||||
|
||||
Args:
|
||||
text: str
|
||||
tool_use: Optional[ChatCompletionToolCallChunk]
|
||||
Returns:
|
||||
Tuple[str, Optional[ChatCompletionToolCallChunk]]
|
||||
|
||||
text: The text to use in the content
|
||||
tool_use: The ChatCompletionToolCallChunk to use in the chunk response
|
||||
"""
|
||||
if self.json_mode is True and tool_use is not None:
|
||||
message = AnthropicConfig._convert_tool_response_to_message(
|
||||
tool_calls=[tool_use]
|
||||
)
|
||||
if message is not None:
|
||||
text = message.content or ""
|
||||
tool_use = None
|
||||
|
||||
return text, tool_use
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
if str_line.startswith("data:"):
|
||||
data_json = json.loads(str_line[5:])
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
if str_line.startswith("data:"):
|
||||
data_json = json.loads(str_line[5:])
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> ModelResponseStream:
|
||||
"""
|
||||
Convert a string chunk to a GenericStreamingChunk
|
||||
|
||||
Note: This is used for Anthropic pass-through streaming logging

We can move __anext__ and __next__ to use this function since it's common logic.
Did not migrate them to minimize changes made in 1 PR.
|
||||
"""
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
if str_line.startswith("data:"):
|
||||
data_json = json.loads(str_line[5:])
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
else:
|
||||
return ModelResponseStream()
|
||||
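The three stream-reading paths above (`__next__`, `__anext__`, and `convert_str_chunk_to_generic_chunk`) share the same `data:` line handling. Below is a small self-contained sketch of that shared parsing step, simplified to return a plain dict (or None for non-data lines) instead of litellm's chunk types.

```python
import json
from typing import Optional, Union


def parse_sse_data_line(chunk: Union[str, bytes]) -> Optional[dict]:
    # Decode bytes, locate the "data:" prefix, and JSON-decode the payload;
    # any other SSE line (e.g. "event: ...") yields no parsed event here.
    str_line = chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk
    index = str_line.find("data:")
    if index != -1:
        str_line = str_line[index:]
    if str_line.startswith("data:"):
        return json.loads(str_line[5:])
    return None


assert parse_sse_data_line(b'data: {"type": "message_start"}') == {"type": "message_start"}
assert parse_sse_data_line("event: content_block_start") is None
```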
@@ -0,0 +1,823 @@
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.constants import (
|
||||
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS,
|
||||
DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET,
|
||||
DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET,
|
||||
DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET,
|
||||
RESPONSE_FORMAT_TOOL_NAME,
|
||||
)
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import anthropic_messages_pt
|
||||
from litellm.llms.base_llm.base_utils import type_to_response_format_param
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.anthropic import (
|
||||
AllAnthropicToolsValues,
|
||||
AnthropicComputerTool,
|
||||
AnthropicHostedTools,
|
||||
AnthropicInputSchema,
|
||||
AnthropicMessagesTool,
|
||||
AnthropicMessagesToolChoice,
|
||||
AnthropicSystemMessageContent,
|
||||
AnthropicThinkingParam,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
REASONING_EFFORT,
|
||||
AllMessageValues,
|
||||
ChatCompletionCachedContent,
|
||||
ChatCompletionRedactedThinkingBlock,
|
||||
ChatCompletionSystemMessage,
|
||||
ChatCompletionThinkingBlock,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from litellm.types.utils import CompletionTokensDetailsWrapper
|
||||
from litellm.types.utils import Message as LitellmMessage
|
||||
from litellm.types.utils import PromptTokensDetailsWrapper
|
||||
from litellm.utils import (
|
||||
ModelResponse,
|
||||
Usage,
|
||||
add_dummy_tool,
|
||||
has_tool_call_blocks,
|
||||
token_counter,
|
||||
)
|
||||
|
||||
from ..common_utils import AnthropicError, AnthropicModelInfo, process_anthropic_headers
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
LoggingClass = LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingClass = Any
|
||||
|
||||
|
||||
class AnthropicConfig(AnthropicModelInfo, BaseConfig):
|
||||
"""
|
||||
Reference: https://docs.anthropic.com/claude/reference/messages_post
|
||||
|
||||
to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
|
||||
"""
|
||||
|
||||
max_tokens: Optional[
|
||||
int
|
||||
] = DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS # anthropic requires a default value (Opus, Sonnet, and Haiku have the same default)
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
system: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[
|
||||
int
|
||||
] = DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS, # You can pass in a value yourself or use the default value 4096
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
system: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_supported_openai_params(self, model: str):
|
||||
params = [
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"response_format",
|
||||
"user",
|
||||
"reasoning_effort",
|
||||
]
|
||||
|
||||
if "claude-3-7-sonnet" in model:
|
||||
params.append("thinking")
|
||||
|
||||
return params
|
||||
|
||||
def get_json_schema_from_pydantic_object(
|
||||
self, response_format: Union[Any, Dict, None]
|
||||
) -> Optional[dict]:
|
||||
return type_to_response_format_param(
|
||||
response_format, ref_template="/$defs/{model}"
|
||||
) # Relevant issue: https://github.com/BerriAI/litellm/issues/7755
|
||||
|
||||
def get_cache_control_headers(self) -> dict:
|
||||
return {
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": "prompt-caching-2024-07-31",
|
||||
}
|
||||
|
||||
def _map_tool_choice(
|
||||
self, tool_choice: Optional[str], parallel_tool_use: Optional[bool]
|
||||
) -> Optional[AnthropicMessagesToolChoice]:
|
||||
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
|
||||
if tool_choice == "auto":
|
||||
_tool_choice = AnthropicMessagesToolChoice(
|
||||
type="auto",
|
||||
)
|
||||
elif tool_choice == "required":
|
||||
_tool_choice = AnthropicMessagesToolChoice(type="any")
|
||||
elif isinstance(tool_choice, dict):
|
||||
_tool_name = tool_choice.get("function", {}).get("name")
|
||||
_tool_choice = AnthropicMessagesToolChoice(type="tool")
|
||||
if _tool_name is not None:
|
||||
_tool_choice["name"] = _tool_name
|
||||
|
||||
if parallel_tool_use is not None:
|
||||
# Anthropic uses 'disable_parallel_tool_use' flag to determine if parallel tool use is allowed
|
||||
# this is the inverse of the openai flag.
|
||||
if _tool_choice is not None:
|
||||
_tool_choice["disable_parallel_tool_use"] = not parallel_tool_use
|
||||
else: # use anthropic defaults and make sure to send the disable_parallel_tool_use flag
|
||||
_tool_choice = AnthropicMessagesToolChoice(
|
||||
type="auto",
|
||||
disable_parallel_tool_use=not parallel_tool_use,
|
||||
)
|
||||
return _tool_choice
|
||||
|
||||
def _map_tool_helper(
|
||||
self, tool: ChatCompletionToolParam
|
||||
) -> AllAnthropicToolsValues:
|
||||
returned_tool: Optional[AllAnthropicToolsValues] = None
|
||||
|
||||
if tool["type"] == "function" or tool["type"] == "custom":
|
||||
_input_schema: dict = tool["function"].get(
|
||||
"parameters",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
)
|
||||
input_schema: AnthropicInputSchema = AnthropicInputSchema(**_input_schema)
|
||||
_tool = AnthropicMessagesTool(
|
||||
name=tool["function"]["name"],
|
||||
input_schema=input_schema,
|
||||
)
|
||||
|
||||
_description = tool["function"].get("description")
|
||||
if _description is not None:
|
||||
_tool["description"] = _description
|
||||
|
||||
returned_tool = _tool
|
||||
|
||||
elif tool["type"].startswith("computer_"):
|
||||
## check if all required 'display_' params are given
|
||||
if "parameters" not in tool["function"]:
|
||||
raise ValueError("Missing required parameter: parameters")
|
||||
|
||||
_display_width_px: Optional[int] = tool["function"]["parameters"].get(
|
||||
"display_width_px"
|
||||
)
|
||||
_display_height_px: Optional[int] = tool["function"]["parameters"].get(
|
||||
"display_height_px"
|
||||
)
|
||||
if _display_width_px is None or _display_height_px is None:
|
||||
raise ValueError(
|
||||
"Missing required parameter: display_width_px or display_height_px"
|
||||
)
|
||||
|
||||
_computer_tool = AnthropicComputerTool(
|
||||
type=tool["type"],
|
||||
name=tool["function"].get("name", "computer"),
|
||||
display_width_px=_display_width_px,
|
||||
display_height_px=_display_height_px,
|
||||
)
|
||||
|
||||
_display_number = tool["function"]["parameters"].get("display_number")
|
||||
if _display_number is not None:
|
||||
_computer_tool["display_number"] = _display_number
|
||||
|
||||
returned_tool = _computer_tool
|
||||
elif tool["type"].startswith("bash_") or tool["type"].startswith(
|
||||
"text_editor_"
|
||||
):
|
||||
function_name = tool["function"].get("name")
|
||||
if function_name is None:
|
||||
raise ValueError("Missing required parameter: name")
|
||||
|
||||
returned_tool = AnthropicHostedTools(
|
||||
type=tool["type"],
|
||||
name=function_name,
|
||||
)
|
||||
if returned_tool is None:
|
||||
raise ValueError(f"Unsupported tool type: {tool['type']}")
|
||||
|
||||
## check if cache_control is set in the tool
|
||||
_cache_control = tool.get("cache_control", None)
|
||||
_cache_control_function = tool.get("function", {}).get("cache_control", None)
|
||||
if _cache_control is not None:
|
||||
returned_tool["cache_control"] = _cache_control
|
||||
elif _cache_control_function is not None and isinstance(
|
||||
_cache_control_function, dict
|
||||
):
|
||||
returned_tool["cache_control"] = ChatCompletionCachedContent(
|
||||
**_cache_control_function # type: ignore
|
||||
)
|
||||
|
||||
return returned_tool
|
||||
|
||||
def _map_tools(self, tools: List) -> List[AllAnthropicToolsValues]:
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
if "input_schema" in tool: # assume in anthropic format
|
||||
anthropic_tools.append(tool)
|
||||
else: # assume openai tool call
|
||||
new_tool = self._map_tool_helper(tool)
|
||||
|
||||
anthropic_tools.append(new_tool)
|
||||
return anthropic_tools
|
||||
|
||||
def _map_stop_sequences(
|
||||
self, stop: Optional[Union[str, List[str]]]
|
||||
) -> Optional[List[str]]:
|
||||
new_stop: Optional[List[str]] = None
|
||||
if isinstance(stop, str):
|
||||
if (
|
||||
stop.isspace() and litellm.drop_params is True
|
||||
): # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
return new_stop
|
||||
new_stop = [stop]
|
||||
elif isinstance(stop, list):
|
||||
new_v = []
|
||||
for v in stop:
|
||||
if (
|
||||
v.isspace() and litellm.drop_params is True
|
||||
): # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
continue
|
||||
new_v.append(v)
|
||||
if len(new_v) > 0:
|
||||
new_stop = new_v
|
||||
return new_stop
|
||||
|
||||
@staticmethod
|
||||
def _map_reasoning_effort(
|
||||
reasoning_effort: Optional[Union[REASONING_EFFORT, str]]
|
||||
) -> Optional[AnthropicThinkingParam]:
|
||||
if reasoning_effort is None:
|
||||
return None
|
||||
elif reasoning_effort == "low":
|
||||
return AnthropicThinkingParam(
|
||||
type="enabled",
|
||||
budget_tokens=DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET,
|
||||
)
|
||||
elif reasoning_effort == "medium":
|
||||
return AnthropicThinkingParam(
|
||||
type="enabled",
|
||||
budget_tokens=DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET,
|
||||
)
|
||||
elif reasoning_effort == "high":
|
||||
return AnthropicThinkingParam(
|
||||
type="enabled",
|
||||
budget_tokens=DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unmapped reasoning effort: {reasoning_effort}")
|
||||
|
||||
def map_response_format_to_anthropic_tool(
|
||||
self, value: Optional[dict], optional_params: dict, is_thinking_enabled: bool
|
||||
) -> Optional[AnthropicMessagesTool]:
|
||||
ignore_response_format_types = ["text"]
|
||||
if (
|
||||
value is None or value["type"] in ignore_response_format_types
|
||||
): # value is a no-op
|
||||
return None
|
||||
|
||||
json_schema: Optional[dict] = None
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
"""
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
|
||||
_tool = self._create_json_tool_call_for_response_format(
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
return _tool
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
is_thinking_enabled = self.is_thinking_enabled(
|
||||
non_default_params=non_default_params
|
||||
)
|
||||
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
# check if optional params already has tools
|
||||
tool_value = self._map_tools(value)
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=tool_value
|
||||
)
|
||||
if param == "tool_choice" or param == "parallel_tool_calls":
|
||||
_tool_choice: Optional[
|
||||
AnthropicMessagesToolChoice
|
||||
] = self._map_tool_choice(
|
||||
tool_choice=non_default_params.get("tool_choice"),
|
||||
parallel_tool_use=non_default_params.get("parallel_tool_calls"),
|
||||
)
|
||||
|
||||
if _tool_choice is not None:
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
|
||||
_value = self._map_stop_sequences(value)
|
||||
if _value is not None:
|
||||
optional_params["stop_sequences"] = _value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "response_format" and isinstance(value, dict):
|
||||
_tool = self.map_response_format_to_anthropic_tool(
|
||||
value, optional_params, is_thinking_enabled
|
||||
)
|
||||
if _tool is None:
|
||||
continue
|
||||
if not is_thinking_enabled:
|
||||
_tool_choice = {"name": RESPONSE_FORMAT_TOOL_NAME, "type": "tool"}
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=[_tool]
|
||||
)
|
||||
if param == "user":
|
||||
optional_params["metadata"] = {"user_id": value}
|
||||
if param == "thinking":
|
||||
optional_params["thinking"] = value
|
||||
elif param == "reasoning_effort" and isinstance(value, str):
|
||||
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
|
||||
value
|
||||
)
|
||||
|
||||
## handle thinking tokens
|
||||
self.update_optional_params_with_thinking_tokens(
|
||||
non_default_params=non_default_params, optional_params=optional_params
|
||||
)
|
||||
return optional_params
|
||||
|
||||
def _create_json_tool_call_for_response_format(
|
||||
self,
|
||||
json_schema: Optional[dict] = None,
|
||||
) -> AnthropicMessagesTool:
|
||||
"""
|
||||
Handles creating a tool call for getting responses in JSON format.
|
||||
|
||||
Args:
|
||||
json_schema (Optional[dict]): The JSON schema the response should be in
|
||||
|
||||
Returns:
|
||||
AnthropicMessagesTool: The tool call to send to Anthropic API to get responses in JSON format
|
||||
"""
|
||||
_input_schema: AnthropicInputSchema = AnthropicInputSchema(
|
||||
type="object",
|
||||
)
|
||||
|
||||
if json_schema is None:
|
||||
# Anthropic raises a 400 BadRequest error if properties is passed as None
|
||||
# see usage with additionalProperties (Example 5) https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/extracting_structured_json.ipynb
|
||||
_input_schema["additionalProperties"] = True
|
||||
_input_schema["properties"] = {}
|
||||
else:
|
||||
_input_schema.update(cast(AnthropicInputSchema, json_schema))
|
||||
|
||||
_tool = AnthropicMessagesTool(
|
||||
name=RESPONSE_FORMAT_TOOL_NAME, input_schema=_input_schema
|
||||
)
|
||||
return _tool
|
||||
|
||||
def translate_system_message(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[AnthropicSystemMessageContent]:
|
||||
"""
|
||||
Translate system message to anthropic format.
|
||||
|
||||
Removes system message from the original list and returns a new list of anthropic system message content.
|
||||
"""
|
||||
system_prompt_indices = []
|
||||
anthropic_system_message_list: List[AnthropicSystemMessageContent] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
valid_content: bool = False
|
||||
system_message_block = ChatCompletionSystemMessage(**message)
|
||||
if isinstance(system_message_block["content"], str):
|
||||
anthropic_system_message_content = AnthropicSystemMessageContent(
|
||||
type="text",
|
||||
text=system_message_block["content"],
|
||||
)
|
||||
if "cache_control" in system_message_block:
|
||||
anthropic_system_message_content[
|
||||
"cache_control"
|
||||
] = system_message_block["cache_control"]
|
||||
anthropic_system_message_list.append(
|
||||
anthropic_system_message_content
|
||||
)
|
||||
valid_content = True
|
||||
elif isinstance(message["content"], list):
|
||||
for _content in message["content"]:
|
||||
anthropic_system_message_content = (
|
||||
AnthropicSystemMessageContent(
|
||||
type=_content.get("type"),
|
||||
text=_content.get("text"),
|
||||
)
|
||||
)
|
||||
if "cache_control" in _content:
|
||||
anthropic_system_message_content[
|
||||
"cache_control"
|
||||
] = _content["cache_control"]
|
||||
|
||||
anthropic_system_message_list.append(
|
||||
anthropic_system_message_content
|
||||
)
|
||||
valid_content = True
|
||||
|
||||
if valid_content:
|
||||
system_prompt_indices.append(idx)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
|
||||
return anthropic_system_message_list
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Translate messages to anthropic format.
|
||||
"""
|
||||
## VALIDATE REQUEST
|
||||
"""
|
||||
Anthropic doesn't support tool calling without `tools=` param specified.
|
||||
"""
|
||||
if (
|
||||
"tools" not in optional_params
|
||||
and messages is not None
|
||||
and has_tool_call_blocks(messages)
|
||||
):
|
||||
if litellm.modify_params:
|
||||
optional_params["tools"] = self._map_tools(
|
||||
add_dummy_tool(custom_llm_provider="anthropic")
|
||||
)
|
||||
else:
|
||||
raise litellm.UnsupportedParamsError(
|
||||
message="Anthropic doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
|
||||
model="",
|
||||
llm_provider="anthropic",
|
||||
)
|
||||
|
||||
# Separate system prompt from rest of message
|
||||
anthropic_system_message_list = self.translate_system_message(messages=messages)
|
||||
# Handling anthropic API Prompt Caching
|
||||
if len(anthropic_system_message_list) > 0:
|
||||
optional_params["system"] = anthropic_system_message_list
|
||||
# Format rest of message according to anthropic guidelines
|
||||
try:
|
||||
anthropic_messages = anthropic_messages_pt(
|
||||
model=model,
|
||||
messages=messages,
|
||||
llm_provider="anthropic",
|
||||
)
|
||||
except Exception as e:
|
||||
raise AnthropicError(
|
||||
status_code=400,
|
||||
message="{}\nReceived Messages={}".format(str(e), messages),
|
||||
) # don't use verbose_logger.exception, if exception is raised
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
## Handle user_id in metadata
|
||||
_litellm_metadata = litellm_params.get("metadata", None)
|
||||
if (
|
||||
_litellm_metadata
|
||||
and isinstance(_litellm_metadata, dict)
|
||||
and "user_id" in _litellm_metadata
|
||||
):
|
||||
optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
def _transform_response_for_json_mode(
|
||||
self,
|
||||
json_mode: Optional[bool],
|
||||
tool_calls: List[ChatCompletionToolCallChunk],
|
||||
) -> Optional[LitellmMessage]:
|
||||
_message: Optional[LitellmMessage] = None
|
||||
if json_mode is True and len(tool_calls) == 1:
|
||||
# check if tool name is the default tool name
|
||||
json_mode_content_str: Optional[str] = None
|
||||
if (
|
||||
"name" in tool_calls[0]["function"]
|
||||
and tool_calls[0]["function"]["name"] == RESPONSE_FORMAT_TOOL_NAME
|
||||
):
|
||||
json_mode_content_str = tool_calls[0]["function"].get("arguments")
|
||||
if json_mode_content_str is not None:
|
||||
_message = AnthropicConfig._convert_tool_response_to_message(
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
return _message
|
||||
|
||||
def extract_response_content(
|
||||
self, completion_response: dict
|
||||
) -> Tuple[
|
||||
str,
|
||||
Optional[List[Any]],
|
||||
Optional[
|
||||
List[
|
||||
Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
|
||||
]
|
||||
],
|
||||
Optional[str],
|
||||
List[ChatCompletionToolCallChunk],
|
||||
]:
|
||||
text_content = ""
|
||||
citations: Optional[List[Any]] = None
|
||||
thinking_blocks: Optional[
|
||||
List[
|
||||
Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
|
||||
]
|
||||
] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: List[ChatCompletionToolCallChunk] = []
|
||||
for idx, content in enumerate(completion_response["content"]):
|
||||
if content["type"] == "text":
|
||||
text_content += content["text"]
|
||||
## TOOL CALLING
|
||||
elif content["type"] == "tool_use":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=content["id"],
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=content["name"],
|
||||
arguments=json.dumps(content["input"]),
|
||||
),
|
||||
index=idx,
|
||||
)
|
||||
)
|
||||
|
||||
elif content.get("thinking", None) is not None:
|
||||
if thinking_blocks is None:
|
||||
thinking_blocks = []
|
||||
thinking_blocks.append(cast(ChatCompletionThinkingBlock, content))
|
||||
elif content["type"] == "redacted_thinking":
|
||||
if thinking_blocks is None:
|
||||
thinking_blocks = []
|
||||
thinking_blocks.append(
|
||||
cast(ChatCompletionRedactedThinkingBlock, content)
|
||||
)
|
||||
|
||||
## CITATIONS
|
||||
if content.get("citations") is not None:
|
||||
if citations is None:
|
||||
citations = []
|
||||
citations.append(content["citations"])
|
||||
if thinking_blocks is not None:
|
||||
reasoning_content = ""
|
||||
for block in thinking_blocks:
|
||||
thinking_content = cast(Optional[str], block.get("thinking"))
|
||||
if thinking_content is not None:
|
||||
reasoning_content += thinking_content
|
||||
|
||||
return text_content, citations, thinking_blocks, reasoning_content, tool_calls
|
||||
|
||||
def calculate_usage(
|
||||
self, usage_object: dict, reasoning_content: Optional[str]
|
||||
) -> Usage:
|
||||
prompt_tokens = usage_object.get("input_tokens", 0)
|
||||
completion_tokens = usage_object.get("output_tokens", 0)
|
||||
_usage = usage_object
|
||||
cache_creation_input_tokens: int = 0
|
||||
cache_read_input_tokens: int = 0
|
||||
|
||||
if "cache_creation_input_tokens" in _usage:
|
||||
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
|
||||
if "cache_read_input_tokens" in _usage:
|
||||
cache_read_input_tokens = _usage["cache_read_input_tokens"]
|
||||
prompt_tokens += cache_read_input_tokens
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
completion_token_details = (
|
||||
CompletionTokensDetailsWrapper(
|
||||
reasoning_tokens=token_counter(
|
||||
text=reasoning_content, count_response_tokens=True
|
||||
)
|
||||
)
|
||||
if reasoning_content
|
||||
else None
|
||||
)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
completion_tokens_details=completion_token_details,
|
||||
)
|
||||
return usage
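# Illustrative sketch (not a real API payload) of the accounting performed by
# calculate_usage above: cache-read tokens are added to prompt_tokens and also
# surfaced via prompt_tokens_details.cached_tokens.
#
#   usage = self.calculate_usage(
#       usage_object={
#           "input_tokens": 10,
#           "output_tokens": 20,
#           "cache_read_input_tokens": 100,
#       },
#       reasoning_content=None,
#   )
#   # usage.prompt_tokens == 110, usage.completion_tokens == 20,
#   # usage.total_tokens == 130, usage.cache_read_input_tokens == 100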
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingClass,
|
||||
request_data: Dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
_hidden_params: Dict = {}
|
||||
_hidden_params["additional_headers"] = process_anthropic_headers(
|
||||
dict(raw_response.headers)
|
||||
)
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=raw_response.text,
|
||||
additional_args={"complete_input_dict": request_data},
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = raw_response.json()
|
||||
except Exception as e:
|
||||
response_headers = getattr(raw_response, "headers", None)
|
||||
raise AnthropicError(
|
||||
message="Unable to get json response - {}, Original Response: {}".format(
|
||||
str(e), raw_response.text
|
||||
),
|
||||
status_code=raw_response.status_code,
|
||||
headers=response_headers,
|
||||
)
|
||||
if "error" in completion_response:
|
||||
response_headers = getattr(raw_response, "headers", None)
|
||||
raise AnthropicError(
|
||||
message=str(completion_response["error"]),
|
||||
status_code=raw_response.status_code,
|
||||
headers=response_headers,
|
||||
)
|
||||
else:
|
||||
text_content = ""
|
||||
citations: Optional[List[Any]] = None
|
||||
thinking_blocks: Optional[
|
||||
List[
|
||||
Union[
|
||||
ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock
|
||||
]
|
||||
]
|
||||
] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
tool_calls: List[ChatCompletionToolCallChunk] = []
|
||||
|
||||
(
|
||||
text_content,
|
||||
citations,
|
||||
thinking_blocks,
|
||||
reasoning_content,
|
||||
tool_calls,
|
||||
) = self.extract_response_content(completion_response=completion_response)
|
||||
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=text_content or None,
|
||||
provider_specific_fields={
|
||||
"citations": citations,
|
||||
"thinking_blocks": thinking_blocks,
|
||||
},
|
||||
thinking_blocks=thinking_blocks,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
## HANDLE JSON MODE - anthropic returns single function call
|
||||
json_mode_message = self._transform_response_for_json_mode(
|
||||
json_mode=json_mode,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
if json_mode_message is not None:
|
||||
completion_response["stop_reason"] = "stop"
|
||||
_message = json_mode_message
|
||||
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
model_response._hidden_params["original_response"] = completion_response[
|
||||
"content"
|
||||
] # allow user to access raw anthropic tool calling response
|
||||
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response["stop_reason"]
|
||||
)
|
||||
|
||||
## CALCULATING USAGE
|
||||
usage = self.calculate_usage(
|
||||
usage_object=completion_response["usage"],
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
setattr(model_response, "usage", usage) # type: ignore
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = completion_response["model"]
|
||||
|
||||
model_response._hidden_params = _hidden_params
|
||||
return model_response
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_message(
|
||||
tool_calls: List[ChatCompletionToolCallChunk],
|
||||
) -> Optional[LitellmMessage]:
|
||||
"""
|
||||
In JSON mode, the Anthropic API returns the JSON schema as a tool call; we convert it to a plain message to follow the OpenAI format
|
||||
|
||||
"""
|
||||
## HANDLE JSON MODE - anthropic returns single function call
|
||||
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get(
|
||||
"arguments"
|
||||
)
|
||||
try:
|
||||
if json_mode_content_str is not None:
|
||||
args = json.loads(json_mode_content_str)
|
||||
if (
|
||||
isinstance(args, dict)
|
||||
and (values := args.get("values")) is not None
|
||||
):
|
||||
_message = litellm.Message(content=json.dumps(values))
|
||||
return _message
|
||||
else:
|
||||
# often the `values` key is not present in the tool response
|
||||
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
|
||||
_message = litellm.Message(content=json.dumps(args))
|
||||
return _message
|
||||
except json.JSONDecodeError:
|
||||
# a JSON decode error can occur; return the original tool response string
|
||||
return litellm.Message(content=json_mode_content_str)
|
||||
return None
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return AnthropicError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=cast(httpx.Headers, headers),
|
||||
)
|
||||
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
This file contains common utils for anthropic calls.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.anthropic import AllAnthropicToolsValues
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class AnthropicError(BaseLLMException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message,
|
||||
headers: Optional[httpx.Headers] = None,
|
||||
):
|
||||
super().__init__(status_code=status_code, message=message, headers=headers)
|
||||
|
||||
|
||||
class AnthropicModelInfo(BaseLLMModelInfo):
|
||||
def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
|
||||
"""
|
||||
Return True if a {"cache_control": ..} block is present on a message or in a message content block.
|
||||
|
||||
Used to check if anthropic prompt caching headers need to be set.
|
||||
"""
|
||||
for message in messages:
|
||||
if message.get("cache_control", None) is not None:
|
||||
return True
|
||||
_message_content = message.get("content")
|
||||
if _message_content is not None and isinstance(_message_content, list):
|
||||
for content in _message_content:
|
||||
if "cache_control" in content:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def is_computer_tool_used(
|
||||
self, tools: Optional[List[AllAnthropicToolsValues]]
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools:
|
||||
if "type" in tool and tool["type"].startswith("computer_"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
|
||||
"""
|
||||
Returns True if non-text media (e.g. a PDF) is passed in the messages.
|
||||
|
||||
"""
|
||||
for message in messages:
|
||||
if (
|
||||
"content" in message
|
||||
and message["content"] is not None
|
||||
and isinstance(message["content"], list)
|
||||
):
|
||||
for content in message["content"]:
|
||||
if "type" in content and content["type"] != "text":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_user_anthropic_beta_headers(
|
||||
self, anthropic_beta_header: Optional[str]
|
||||
) -> Optional[List[str]]:
|
||||
if anthropic_beta_header is None:
|
||||
return None
|
||||
return anthropic_beta_header.split(",")
|
||||
|
||||
def get_anthropic_headers(
|
||||
self,
|
||||
api_key: str,
|
||||
anthropic_version: Optional[str] = None,
|
||||
computer_tool_used: bool = False,
|
||||
prompt_caching_set: bool = False,
|
||||
pdf_used: bool = False,
|
||||
is_vertex_request: bool = False,
|
||||
user_anthropic_beta_headers: Optional[List[str]] = None,
|
||||
) -> dict:
|
||||
betas = set()
|
||||
if prompt_caching_set:
|
||||
betas.add("prompt-caching-2024-07-31")
|
||||
if computer_tool_used:
|
||||
betas.add("computer-use-2024-10-22")
|
||||
if pdf_used:
|
||||
betas.add("pdfs-2024-09-25")
|
||||
headers = {
|
||||
"anthropic-version": anthropic_version or "2023-06-01",
|
||||
"x-api-key": api_key,
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
if user_anthropic_beta_headers is not None:
|
||||
betas.update(user_anthropic_beta_headers)
|
||||
|
||||
# Don't send any beta headers to Vertex, Vertex has failed requests when they are sent
|
||||
if is_vertex_request is True:
|
||||
pass
|
||||
elif len(betas) > 0:
|
||||
headers["anthropic-beta"] = ",".join(betas)
|
||||
|
||||
return headers
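# Illustrative sketch of the headers produced by get_anthropic_headers above
# (api key value is a placeholder):
#
#   self.get_anthropic_headers(
#       api_key="sk-ant-...",
#       prompt_caching_set=True,
#       computer_tool_used=True,
#   )
#   # -> {"anthropic-version": "2023-06-01", "x-api-key": "sk-ant-...",
#   #     "accept": "application/json", "content-type": "application/json",
#   #     "anthropic-beta": ...}
#   # where the "anthropic-beta" header contains "prompt-caching-2024-07-31"
#   # and "computer-use-2024-10-22" (order not guaranteed, since the betas
#   # are collected in a set).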
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict:
|
||||
if api_key is None:
|
||||
raise litellm.AuthenticationError(
|
||||
message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
|
||||
llm_provider="anthropic",
|
||||
model=model,
|
||||
)
|
||||
|
||||
tools = optional_params.get("tools")
|
||||
prompt_caching_set = self.is_cache_control_set(messages=messages)
|
||||
computer_tool_used = self.is_computer_tool_used(tools=tools)
|
||||
pdf_used = self.is_pdf_used(messages=messages)
|
||||
user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
|
||||
anthropic_beta_header=headers.get("anthropic-beta")
|
||||
)
|
||||
anthropic_headers = self.get_anthropic_headers(
|
||||
computer_tool_used=computer_tool_used,
|
||||
prompt_caching_set=prompt_caching_set,
|
||||
pdf_used=pdf_used,
|
||||
api_key=api_key,
|
||||
is_vertex_request=optional_params.get("is_vertex_request", False),
|
||||
user_anthropic_beta_headers=user_anthropic_beta_headers,
|
||||
)
|
||||
|
||||
headers = {**headers, **anthropic_headers}
|
||||
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
return (
|
||||
api_base
|
||||
or get_secret_str("ANTHROPIC_API_BASE")
|
||||
or "https://api.anthropic.com"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
|
||||
return api_key or get_secret_str("ANTHROPIC_API_KEY")
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: Optional[str] = None) -> Optional[str]:
|
||||
return model.replace("anthropic/", "") if model else None
|
||||
|
||||
def get_models(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> List[str]:
|
||||
api_base = AnthropicModelInfo.get_api_base(api_base)
|
||||
api_key = AnthropicModelInfo.get_api_key(api_key)
|
||||
if api_base is None or api_key is None:
|
||||
raise ValueError(
|
||||
"ANTHROPIC_API_BASE or ANTHROPIC_API_KEY is not set. Please set the environment variable, to query Anthropic's `/models` endpoint."
|
||||
)
|
||||
response = litellm.module_level_client.get(
|
||||
url=f"{api_base}/v1/models",
|
||||
headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"},
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError:
|
||||
raise Exception(
|
||||
f"Failed to fetch models from Anthropic. Status code: {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
models = response.json()["data"]
|
||||
|
||||
litellm_model_names = []
|
||||
for model in models:
|
||||
stripped_model_name = model["id"]
|
||||
litellm_model_name = "anthropic/" + stripped_model_name
|
||||
litellm_model_names.append(litellm_model_name)
|
||||
return litellm_model_names
|
||||
|
||||
|
||||
def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
|
||||
openai_headers = {}
|
||||
if "anthropic-ratelimit-requests-limit" in headers:
|
||||
openai_headers["x-ratelimit-limit-requests"] = headers[
|
||||
"anthropic-ratelimit-requests-limit"
|
||||
]
|
||||
if "anthropic-ratelimit-requests-remaining" in headers:
|
||||
openai_headers["x-ratelimit-remaining-requests"] = headers[
|
||||
"anthropic-ratelimit-requests-remaining"
|
||||
]
|
||||
if "anthropic-ratelimit-tokens-limit" in headers:
|
||||
openai_headers["x-ratelimit-limit-tokens"] = headers[
|
||||
"anthropic-ratelimit-tokens-limit"
|
||||
]
|
||||
if "anthropic-ratelimit-tokens-remaining" in headers:
|
||||
openai_headers["x-ratelimit-remaining-tokens"] = headers[
|
||||
"anthropic-ratelimit-tokens-remaining"
|
||||
]
|
||||
|
||||
llm_response_headers = {
|
||||
"{}-{}".format("llm_provider", k): v for k, v in headers.items()
|
||||
}
|
||||
|
||||
additional_headers = {**llm_response_headers, **openai_headers}
|
||||
return additional_headers
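# Minimal sketch (hypothetical header values) of the mapping performed by
# process_anthropic_headers:
#
#   raw = {
#       "anthropic-ratelimit-requests-limit": "1000",
#       "anthropic-ratelimit-requests-remaining": "999",
#       "request-id": "req_123",
#   }
#   mapped = process_anthropic_headers(raw)
#   # mapped["x-ratelimit-limit-requests"] == "1000"
#   # mapped["x-ratelimit-remaining-requests"] == "999"
#   # every original header is also preserved under an "llm_provider-" prefix,
#   # e.g. mapped["llm_provider-request-id"] == "req_123"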
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Anthropic /complete API - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Translation logic for anthropic's `/v1/complete` endpoint
|
||||
|
||||
Litellm provider slug: `anthropic_text/<model_name>`
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.constants import DEFAULT_MAX_TOKENS
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
custom_prompt,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import (
|
||||
BaseConfig,
|
||||
BaseLLMException,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicTextError(BaseLLMException):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.anthropic.com/v1/complete"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
status_code=self.status_code,
|
||||
request=self.request,
|
||||
response=self.response,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class AnthropicTextConfig(BaseConfig):
|
||||
"""
|
||||
Reference: https://docs.anthropic.com/claude/reference/complete_post
|
||||
|
||||
To pass metadata to Anthropic, use the format {"user_id": "any-relevant-information"}.
|
||||
"""
|
||||
|
||||
max_tokens_to_sample: Optional[
|
||||
int
|
||||
] = litellm.max_tokens # anthropic requires a default
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
top_k: Optional[int] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens_to_sample: Optional[
|
||||
int
|
||||
] = DEFAULT_MAX_TOKENS, # anthropic requires a default
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
# makes headers for API call
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
|
||||
)
|
||||
_headers = {
|
||||
"accept": "application/json",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
"x-api-key": api_key,
|
||||
}
|
||||
headers.update(_headers)
|
||||
return headers
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
prompt = self._get_anthropic_text_prompt_from_messages(
|
||||
messages=messages, model=model
|
||||
)
|
||||
## Load Config
|
||||
config = litellm.AnthropicTextConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
def get_supported_openai_params(self, model: str):
|
||||
"""
|
||||
Anthropic /complete API Ref: https://docs.anthropic.com/en/api/complete
|
||||
"""
|
||||
return [
|
||||
"stream",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
"user",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Follows the same logic as the AnthropicConfig.map_openai_params method (which targets the Anthropic /messages API).
|
||||
|
||||
Note: the only difference between AnthropicConfig and AnthropicTextConfig is in the get_supported_openai_params method.
|
||||
API Ref: https://docs.anthropic.com/en/api/complete
|
||||
"""
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens_to_sample"] = value
|
||||
if param == "max_completion_tokens":
|
||||
optional_params["max_tokens_to_sample"] = value
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
|
||||
_value = litellm.AnthropicConfig()._map_stop_sequences(value)
|
||||
if _value is not None:
|
||||
optional_params["stop_sequences"] = _value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "user":
|
||||
optional_params["metadata"] = {"user_id": value}
|
||||
|
||||
return optional_params
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: str,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
completion_response = raw_response.json()
|
||||
except Exception:
|
||||
raise AnthropicTextError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
prompt = self._get_anthropic_text_prompt_from_messages(
|
||||
messages=messages, model=model
|
||||
)
|
||||
if "error" in completion_response:
|
||||
raise AnthropicTextError(
|
||||
message=str(completion_response["error"]),
|
||||
status_code=raw_response.status_code,
|
||||
)
|
||||
else:
|
||||
if len(completion_response["completion"]) > 0:
|
||||
model_response.choices[0].message.content = completion_response[ # type: ignore
|
||||
"completion"
|
||||
]
|
||||
model_response.choices[0].finish_reason = completion_response["stop_reason"]
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
) ##[TODO] use the anthropic tokenizer here
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
) ##[TODO] use the anthropic tokenizer here
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return AnthropicTextError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_anthropic_text_model(model: str) -> bool:
|
||||
return model == "claude-2" or model == "claude-instant-1"
|
||||
|
||||
def _get_anthropic_text_prompt_from_messages(
|
||||
self, messages: List[AllMessageValues], model: str
|
||||
) -> str:
|
||||
custom_prompt_dict = litellm.custom_prompt_dict
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
return str(prompt)
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
):
|
||||
return AnthropicTextCompletionResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicTextCompletionResponseIterator(BaseModelResponseIterator):
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields = None
|
||||
index = int(chunk.get("index", 0))
|
||||
_chunk_text = chunk.get("completion", None)
|
||||
if _chunk_text is not None and isinstance(_chunk_text, str):
|
||||
text = _chunk_text
|
||||
finish_reason = chunk.get("stop_reason", None)
|
||||
if finish_reason is not None:
|
||||
is_finished = True
|
||||
returned_chunk = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
)
|
||||
|
||||
return returned_chunk
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Helper util for handling anthropic-specific cost calculation
|
||||
- e.g.: prompt caching
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
|
||||
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
|
||||
|
||||
Input:
|
||||
- model: str, the model name without provider prefix
|
||||
- usage: LiteLLM Usage block, containing anthropic caching information
|
||||
|
||||
Returns:
|
||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||
"""
|
||||
return generic_cost_per_token(
|
||||
model=model, usage=usage, custom_llm_provider="anthropic"
|
||||
)
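# Minimal usage sketch (model name and token counts are illustrative):
#
#   from litellm.types.utils import Usage
#
#   usage = Usage(prompt_tokens=1200, completion_tokens=300, total_tokens=1500)
#   prompt_cost, completion_cost = cost_per_token(
#       model="claude-3-5-sonnet-20240620", usage=usage
#   )
#   # both values are USD floats; any anthropic cache-read / cache-creation
#   # token counts attached to the Usage block are priced by
#   # generic_cost_per_token.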
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
- call /messages on Anthropic API
|
||||
- Make streaming + non-streaming requests - pass them through directly to Anthropic; no special handling is needed here
|
||||
- Ensure requests are logged in the DB - stream + non-stream
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import AsyncIterator, Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.anthropic_messages.transformation import (
|
||||
BaseAnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import ProviderSpecificHeader
|
||||
from litellm.utils import ProviderConfigManager, client
|
||||
|
||||
|
||||
class AnthropicMessagesHandler:
|
||||
@staticmethod
|
||||
async def _handle_anthropic_streaming(
|
||||
response: httpx.Response,
|
||||
request_body: dict,
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
) -> AsyncIterator:
|
||||
"""Helper function to handle Anthropic streaming responses using the existing logging handlers"""
|
||||
from datetime import datetime
|
||||
|
||||
from litellm.proxy.pass_through_endpoints.streaming_handler import (
|
||||
PassThroughStreamingHandler,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.success_handler import (
|
||||
PassThroughEndpointLogging,
|
||||
)
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
EndpointType,
|
||||
)
|
||||
|
||||
# Create success handler object
|
||||
passthrough_success_handler_obj = PassThroughEndpointLogging()
|
||||
|
||||
# Use the existing streaming handler for Anthropic
|
||||
start_time = datetime.now()
|
||||
return PassThroughStreamingHandler.chunk_processor(
|
||||
response=response,
|
||||
request_body=request_body,
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
endpoint_type=EndpointType.ANTHROPIC,
|
||||
start_time=start_time,
|
||||
passthrough_success_handler_obj=passthrough_success_handler_obj,
|
||||
url_route="/v1/messages",
|
||||
)
|
||||
|
||||
|
||||
@client
|
||||
async def anthropic_messages(
|
||||
max_tokens: int,
|
||||
messages: List[Dict],
|
||||
model: str,
|
||||
metadata: Optional[Dict] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
system: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
thinking: Optional[Dict] = None,
|
||||
tool_choice: Optional[Dict] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
|
||||
"""
|
||||
Makes Anthropic `/v1/messages` API calls in the Anthropic API spec (see the usage sketch after this function).
|
||||
"""
|
||||
# Use provided client or create a new one
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
(
|
||||
model,
|
||||
_custom_llm_provider,
|
||||
dynamic_api_key,
|
||||
dynamic_api_base,
|
||||
) = litellm.get_llm_provider(
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=optional_params.api_base,
|
||||
api_key=optional_params.api_key,
|
||||
)
|
||||
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
|
||||
ProviderConfigManager.get_provider_anthropic_messages_config(
|
||||
model=model,
|
||||
provider=litellm.LlmProviders(_custom_llm_provider),
|
||||
)
|
||||
)
|
||||
if anthropic_messages_provider_config is None:
|
||||
raise ValueError(
|
||||
f"Anthropic messages provider config not found for model: {model}"
|
||||
)
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.ANTHROPIC
|
||||
)
|
||||
else:
|
||||
async_httpx_client = client
|
||||
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
|
||||
# Prepare headers
|
||||
provider_specific_header = cast(
|
||||
Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None)
|
||||
)
|
||||
extra_headers = (
|
||||
provider_specific_header.get("extra_headers", {})
|
||||
if provider_specific_header
|
||||
else {}
|
||||
)
|
||||
headers = anthropic_messages_provider_config.validate_environment(
|
||||
headers=extra_headers or {},
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=model,
|
||||
optional_params=dict(optional_params),
|
||||
litellm_params={
|
||||
"metadata": kwargs.get("metadata", {}),
|
||||
"preset_cache_key": None,
|
||||
"stream_response": {},
|
||||
**optional_params.model_dump(exclude_unset=True),
|
||||
},
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
)
|
||||
# Prepare request body
|
||||
request_body = locals().copy()
|
||||
request_body = {
|
||||
k: v
|
||||
for k, v in request_body.items()
|
||||
if k
|
||||
in anthropic_messages_provider_config.get_supported_anthropic_messages_params(
|
||||
model=model
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
request_body["stream"] = stream
|
||||
request_body["model"] = model
|
||||
litellm_logging_obj.stream = stream
|
||||
litellm_logging_obj.model_call_details.update(request_body)
|
||||
|
||||
# Make the request
|
||||
request_url = anthropic_messages_provider_config.get_complete_url(
|
||||
api_base=api_base, model=model
|
||||
)
|
||||
|
||||
litellm_logging_obj.pre_call(
|
||||
input=[{"role": "user", "content": json.dumps(request_body)}],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": request_body,
|
||||
"api_base": str(request_url),
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
response = await async_httpx_client.post(
|
||||
url=request_url,
|
||||
headers=headers,
|
||||
data=json.dumps(request_body),
|
||||
stream=stream or False,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# used for logging + cost tracking
|
||||
litellm_logging_obj.model_call_details["httpx_response"] = response
|
||||
|
||||
if stream:
|
||||
return await AnthropicMessagesHandler._handle_anthropic_streaming(
|
||||
response=response,
|
||||
request_body=request_body,
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
)
|
||||
else:
|
||||
return response.json()
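# Minimal usage sketch, assuming an asyncio entrypoint and ANTHROPIC_API_KEY in
# the environment (parameters mirror the signature above; values illustrative):
#
#   import asyncio
#
#   async def main():
#       response = await anthropic_messages(
#           model="anthropic/claude-3-5-sonnet-20240620",
#           max_tokens=256,
#           messages=[{"role": "user", "content": "Hello"}],
#       )
#       print(response)
#
#   asyncio.run(main())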
|
||||
@@ -0,0 +1,47 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.llms.base_llm.anthropic_messages.transformation import (
|
||||
BaseAnthropicMessagesConfig,
|
||||
)
|
||||
|
||||
DEFAULT_ANTHROPIC_API_BASE = "https://api.anthropic.com"
|
||||
DEFAULT_ANTHROPIC_API_VERSION = "2023-06-01"
|
||||
|
||||
|
||||
class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
|
||||
def get_supported_anthropic_messages_params(self, model: str) -> list:
|
||||
return [
|
||||
"messages",
|
||||
"model",
|
||||
"system",
|
||||
"max_tokens",
|
||||
"stop_sequences",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"thinking",
|
||||
# TODO: Add Anthropic `metadata` support
|
||||
# "metadata",
|
||||
]
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
api_base = api_base or DEFAULT_ANTHROPIC_API_BASE
|
||||
if not api_base.endswith("/v1/messages"):
|
||||
api_base = f"{api_base}/v1/messages"
|
||||
return api_base
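# Illustrative behaviour of get_complete_url (default base is
# DEFAULT_ANTHROPIC_API_BASE; model name is a placeholder):
#
#   get_complete_url(api_base=None, model="claude-3-5-sonnet-20240620")
#   # -> "https://api.anthropic.com/v1/messages"
#   get_complete_url(api_base="https://my-proxy.example.com", model="...")
#   # -> "https://my-proxy.example.com/v1/messages"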
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
if "x-api-key" not in headers:
|
||||
headers["x-api-key"] = api_key
|
||||
if "anthropic-version" not in headers:
|
||||
headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION
|
||||
if "content-type" not in headers:
|
||||
headers["content-type"] = "application/json"
|
||||
return headers
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1025
.venv/lib/python3.10/site-packages/litellm/llms/azure/assistants.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,198 @@
|
||||
import uuid
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
||||
from litellm.types.utils import FileTypes
|
||||
from litellm.utils import (
|
||||
TranscriptionResponse,
|
||||
convert_to_model_response_object,
|
||||
extract_duration_from_srt_or_vtt,
|
||||
)
|
||||
|
||||
from .azure import AzureChatCompletion
|
||||
from .common_utils import AzureOpenAIError
|
||||
|
||||
|
||||
class AzureAudioTranscription(AzureChatCompletion):
|
||||
def audio_transcriptions(
|
||||
self,
|
||||
model: str,
|
||||
audio_file: FileTypes,
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
max_retries: int,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client=None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
atranscription: bool = False,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
|
||||
data = {"model": model, "file": audio_file, **optional_params}
|
||||
|
||||
if atranscription is True:
|
||||
return self.async_audio_transcriptions(
|
||||
audio_file=audio_file,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=False,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=f"audio_file_{uuid.uuid4()}",
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"atranscription": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
|
||||
response = azure_client.audio.transcriptions.create(
|
||||
**data, timeout=timeout # type: ignore
|
||||
)
|
||||
|
||||
if isinstance(response, BaseModel):
|
||||
stringified_response = response.model_dump()
|
||||
else:
|
||||
stringified_response = TranscriptionResponse(text=response).model_dump()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=get_audio_file_name(audio_file),
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
|
||||
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
|
||||
return final_response
|
||||
|
||||
async def async_audio_transcriptions(
|
||||
self,
|
||||
audio_file: FileTypes,
|
||||
model: str,
|
||||
data: dict,
|
||||
model_response: TranscriptionResponse,
|
||||
timeout: float,
|
||||
logging_obj: Any,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> TranscriptionResponse:
|
||||
response = None
|
||||
try:
|
||||
async_azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(async_azure_client, AsyncAzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="async_azure_client is not an instance of AsyncAzureOpenAI",
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=f"audio_file_{uuid.uuid4()}",
|
||||
api_key=async_azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {
|
||||
"Authorization": f"Bearer {async_azure_client.api_key}"
|
||||
},
|
||||
"api_base": async_azure_client._base_url._uri_reference,
|
||||
"atranscription": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
|
||||
raw_response = (
|
||||
await async_azure_client.audio.transcriptions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
headers = dict(raw_response.headers)
|
||||
response = raw_response.parse()
|
||||
|
||||
if isinstance(response, BaseModel):
|
||||
stringified_response = response.model_dump()
|
||||
else:
|
||||
stringified_response = TranscriptionResponse(text=response).model_dump()
|
||||
duration = extract_duration_from_srt_or_vtt(response)
|
||||
stringified_response["duration"] = duration
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=get_audio_file_name(audio_file),
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"headers": {
|
||||
"Authorization": f"Bearer {async_azure_client.api_key}"
|
||||
},
|
||||
"api_base": async_azure_client._base_url._uri_reference,
|
||||
"atranscription": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
|
||||
response = convert_to_model_response_object(
|
||||
_response_headers=headers,
|
||||
response_object=stringified_response,
|
||||
model_response_object=model_response,
|
||||
hidden_params=hidden_params,
|
||||
response_type="audio_transcription",
|
||||
)
|
||||
if not isinstance(response, TranscriptionResponse):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="response is not an instance of TranscriptionResponse",
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
1341
.venv/lib/python3.10/site-packages/litellm/llms/azure/azure.py
Normal file
File diff suppressed because it is too large
Binary file not shown.
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Azure Batches API Handler
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureBatchesAPI(BaseAzureLLM):
|
||||
"""
|
||||
Azure methods supported for batches:
|
||||
- create_batch()
|
||||
- retrieve_batch()
|
||||
- cancel_batch()
|
||||
- list_batch()
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def acreate_batch(
|
||||
self,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
azure_client: AsyncAzureOpenAI,
|
||||
) -> LiteLLMBatch:
|
||||
response = await azure_client.batches.create(**create_batch_data)
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
def create_batch(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_batch_data: CreateBatchRequest,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.acreate_batch( # type: ignore
|
||||
create_batch_data=create_batch_data, azure_client=azure_client
|
||||
)
|
||||
response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data)
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
async def aretrieve_batch(
|
||||
self,
|
||||
retrieve_batch_data: RetrieveBatchRequest,
|
||||
client: AsyncAzureOpenAI,
|
||||
) -> LiteLLMBatch:
|
||||
response = await client.batches.retrieve(**retrieve_batch_data)
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
def retrieve_batch(
|
||||
self,
|
||||
_is_async: bool,
|
||||
retrieve_batch_data: RetrieveBatchRequest,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.aretrieve_batch( # type: ignore
|
||||
retrieve_batch_data=retrieve_batch_data, client=azure_client
|
||||
)
|
||||
response = cast(AzureOpenAI, azure_client).batches.retrieve(
|
||||
**retrieve_batch_data
|
||||
)
|
||||
return LiteLLMBatch(**response.model_dump())
|
||||
|
||||
async def acancel_batch(
|
||||
self,
|
||||
cancel_batch_data: CancelBatchRequest,
|
||||
client: AsyncAzureOpenAI,
|
||||
) -> Batch:
|
||||
response = await client.batches.cancel(**cancel_batch_data)
|
||||
return response
|
||||
|
||||
def cancel_batch(
|
||||
self,
|
||||
_is_async: bool,
|
||||
cancel_batch_data: CancelBatchRequest,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
response = azure_client.batches.cancel(**cancel_batch_data)
|
||||
return response
|
||||
|
||||
async def alist_batches(
|
||||
self,
|
||||
client: AsyncAzureOpenAI,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
response = await client.batches.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
||||
|
||||
def list_batches(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
azure_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
litellm_params=litellm_params or {},
|
||||
)
|
||||
if azure_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.alist_batches( # type: ignore
|
||||
client=azure_client, after=after, limit=limit
|
||||
)
|
||||
response = azure_client.batches.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,311 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from httpx._models import Headers, Response
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
convert_to_azure_openai_messages,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.azure import (
|
||||
API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT,
|
||||
API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
from ....exceptions import UnsupportedParamsError
|
||||
from ....types.llms.openai import AllMessageValues
|
||||
from ...base_llm.chat.transformation import BaseConfig
|
||||
from ..common_utils import AzureOpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
LoggingClass = LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingClass = Any
|
||||
|
||||
|
||||
class AzureOpenAIConfig(BaseConfig):
|
||||
"""
|
||||
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
|
||||
|
||||
The class `AzureOpenAIConfig` provides configuration for OpenAI's Chat API interface, for use with Azure. Below are the parameters:
|
||||
|
||||
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||
|
||||
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||
|
||||
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||
|
||||
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||
|
||||
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||
|
||||
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"temperature",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"function_call",
|
||||
"functions",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"prediction",
|
||||
"modalities",
|
||||
"audio",
|
||||
]
|
||||
|
||||
def _is_response_format_supported_model(self, model: str) -> bool:
|
||||
"""
|
||||
- all 4o models are supported
|
||||
- check if 'supports_response_schema' is True from get_model_info
|
||||
- [TODO] support smart retries for 3.5 models (some supported, some not)
|
||||
"""
|
||||
if "4o" in model:
|
||||
return True
|
||||
elif supports_response_schema(model):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_response_format_supported_api_version(
|
||||
self, api_version_year: str, api_version_month: str
|
||||
) -> bool:
|
||||
"""
|
||||
- check if api_version is supported for response_format
|
||||
- returns True if the API version is equal to or newer than the supported version
|
||||
"""
|
||||
api_year = int(api_version_year)
|
||||
api_month = int(api_version_month)
|
||||
supported_year = int(API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT)
|
||||
supported_month = int(API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT)
|
||||
|
||||
# If the year is greater than supported year, it's definitely supported
|
||||
if api_year > supported_year:
|
||||
return True
|
||||
# If the year is less than supported year, it's not supported
|
||||
elif api_year < supported_year:
|
||||
return False
|
||||
# If same year, check if month is >= supported month
|
||||
else:
|
||||
return api_month >= supported_month
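# Worked example (the supported-version cutoff shown here is illustrative,
# assuming API_VERSION_YEAR/MONTH_SUPPORTED_RESPONSE_FORMAT resolve to 2024-08):
#   "2025-01-01" -> True  (later year)
#   "2024-10-21" -> True  (same year, month >= cutoff)
#   "2024-06-01" -> False (same year, month < cutoff)
#   "2023-12-01" -> False (earlier year)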
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
api_version: str = "",
|
||||
) -> dict:
|
||||
supported_openai_params = self.get_supported_openai_params(model)
|
||||
|
||||
api_version_times = api_version.split("-")
|
||||
api_version_year = api_version_times[0]
|
||||
api_version_month = api_version_times[1]
|
||||
api_version_day = api_version_times[2]
|
||||
for param, value in non_default_params.items():
|
||||
if param == "tool_choice":
|
||||
"""
|
||||
This parameter requires API version 2023-12-01-preview or later
|
||||
|
||||
tool_choice='required' is not supported as of 2024-05-01-preview
|
||||
"""
|
||||
## check if api version supports this param ##
|
||||
if (
|
||||
api_version_year < "2023"
|
||||
or (api_version_year == "2023" and api_version_month < "12")
|
||||
or (
|
||||
api_version_year == "2023"
|
||||
and api_version_month == "12"
|
||||
and api_version_day < "01"
|
||||
)
|
||||
):
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
|
||||
)
|
||||
elif value == "required" and (
|
||||
api_version_year == "2024" and api_version_month <= "05"
|
||||
): ## check if tool_choice value is supported ##
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
|
||||
)
|
||||
else:
|
||||
optional_params["tool_choice"] = value
|
||||
elif param == "response_format" and isinstance(value, dict):
|
||||
_is_response_format_supported_model = (
|
||||
self._is_response_format_supported_model(model)
|
||||
)
|
||||
|
||||
is_response_format_supported_api_version = (
|
||||
self._is_response_format_supported_api_version(
|
||||
api_version_year, api_version_month
|
||||
)
|
||||
)
|
||||
is_response_format_supported = (
|
||||
is_response_format_supported_api_version
|
||||
and _is_response_format_supported_model
|
||||
)
|
||||
|
||||
optional_params = self._add_response_format_to_tools(
|
||||
optional_params=optional_params,
|
||||
value=value,
|
||||
is_response_format_supported=is_response_format_supported,
|
||||
)
|
||||
elif param == "tools" and isinstance(value, list):
|
||||
optional_params.setdefault("tools", [])
|
||||
optional_params["tools"].extend(value)
|
||||
elif param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
|
||||
return optional_params
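A minimal usage sketch of the mapping above. It assumes the enclosing class is the one litellm exports as `AzureOpenAIConfig`; the behaviour shown (keep `tool_choice` on a new enough `api_version`, silently drop it on an older one when `drop_params=True`) follows from the branches above.

```python
import litellm

config = litellm.AzureOpenAIConfig()  # assumed export name for the class above

# New enough api_version: tool_choice is passed through.
params = config.map_openai_params(
    non_default_params={"tool_choice": "auto"},
    optional_params={},
    model="gpt-4o",
    drop_params=False,
    api_version="2024-06-01",
)
assert params.get("tool_choice") == "auto"

# Old api_version + drop_params=True: tool_choice is silently dropped.
params = config.map_openai_params(
    non_default_params={"tool_choice": "auto"},
    optional_params={},
    model="gpt-4o",
    drop_params=True,
    api_version="2023-06-01-preview",
)
assert "tool_choice" not in params
```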
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
messages = convert_to_azure_openai_messages(messages)
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingClass,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
|
||||
)
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
return {"token": "azure_ad_token"}
|
||||
|
||||
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "token":
|
||||
optional_params["azure_ad_token"] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||
"""
|
||||
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
|
||||
def get_us_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
|
||||
"""
|
||||
return [
|
||||
"us",
|
||||
"eastus",
|
||||
"eastus2",
|
||||
"eastus2euap",
|
||||
"eastus3",
|
||||
"southcentralus",
|
||||
"westus",
|
||||
"westus2",
|
||||
"westus3",
|
||||
"westus4",
|
||||
]
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, Headers]
|
||||
) -> BaseLLMException:
|
||||
return AzureOpenAIError(
|
||||
message=error_message, status_code=status_code, headers=headers
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Handler file for calls to Azure OpenAI's o1/o3 family of models
|
||||
|
||||
Written separately to handle faking streaming for o1 and o3 models.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ...openai.openai import OpenAIChatCompletion
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
|
||||
def completion(
|
||||
self,
|
||||
model_response: ModelResponse,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
logging_obj: Any,
|
||||
model: Optional[str] = None,
|
||||
messages: Optional[list] = None,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
dynamic_params: Optional[bool] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
acompletion: bool = False,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
custom_prompt_dict: dict = {},
|
||||
client=None,
|
||||
organization: Optional[str] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
drop_params: Optional[bool] = None,
|
||||
):
|
||||
client = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=acompletion,
|
||||
)
|
||||
return super().completion(
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
dynamic_params=dynamic_params,
|
||||
azure_ad_token=azure_ad_token,
|
||||
acompletion=acompletion,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client,
|
||||
organization=organization,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
drop_params=drop_params,
|
||||
)
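For context, a minimal sketch of how a caller typically reaches this handler through litellm's public API. The deployment name is hypothetical, and the usual AZURE_API_KEY / AZURE_API_BASE / AZURE_API_VERSION environment variables are assumed to be set.

```python
import litellm

# o-series Azure deployments are routed to AzureOpenAIO1ChatCompletion;
# where the model doesn't stream natively, litellm fakes the stream.
response = litellm.completion(
    model="azure/o1-preview",  # hypothetical deployment name
    messages=[{"role": "user", "content": "Explain quicksort briefly."}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```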
|
||||
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Support for o1 and o3 model families
|
||||
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
Translations handled by LiteLLM:
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- Logprobs => drop param (if user opts in to dropping param)
|
||||
- Temperature => drop param (if user opts in to dropping param)
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
from ...openai.chat.o_series_transformation import OpenAIOSeriesConfig
|
||||
|
||||
|
||||
class AzureOpenAIO1Config(OpenAIOSeriesConfig):
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Get the supported OpenAI params for the Azure O-Series models
|
||||
"""
|
||||
all_openai_params = litellm.OpenAIGPTConfig().get_supported_openai_params(
|
||||
model=model
|
||||
)
|
||||
non_supported_params = [
|
||||
"logprobs",
|
||||
"top_p",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"top_logprobs",
|
||||
]
|
||||
|
||||
o_series_only_param = ["reasoning_effort"]
|
||||
all_openai_params.extend(o_series_only_param)
|
||||
return [
|
||||
param for param in all_openai_params if param not in non_supported_params
|
||||
]
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Streaming is faked for Azure O-Series models by default; o3 models stream natively, and any model can opt in via model_info={"supports_native_streaming": true}.
|
||||
"""
|
||||
|
||||
if stream is not True:
|
||||
return False
|
||||
|
||||
if (
|
||||
model and "o3" in model
|
||||
): # o3 models support streaming - https://github.com/BerriAI/litellm/issues/8274
|
||||
return False
|
||||
|
||||
if model is not None:
|
||||
try:
|
||||
model_info = get_model_info(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
if (
|
||||
model_info.get("supports_native_streaming") is True
|
||||
): # allow user to override default with model_info={"supports_native_streaming": true}
|
||||
return False
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Error getting model info in AzureOpenAIO1Config: {e}"
|
||||
)
|
||||
return True
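A small sketch of the decisions the method above makes. The module path is assumed to match this commit's folder layout, and the model names are illustrative (a deployment name that is not in litellm's model map simply falls into the `except` branch above).

```python
from litellm.llms.azure.chat.o_series_transformation import AzureOpenAIO1Config

config = AzureOpenAIO1Config()

# No streaming requested -> nothing to fake.
assert config.should_fake_stream(model="o1-my-deployment", stream=None) is False
# o3 models stream natively -> not faked.
assert config.should_fake_stream(model="o3-my-deployment", stream=True) is False
# Other o-series deployments -> streaming is faked.
assert config.should_fake_stream(model="o1-my-deployment", stream=True) is True
```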
|
||||
|
||||
def is_o_series_model(self, model: str) -> bool:
|
||||
return "o1" in model or "o3" in model or "o4" in model or "o_series/" in model
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
model = model.replace(
|
||||
"o_series/", ""
|
||||
) # handle o_series/my-random-deployment-name
|
||||
return super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
@@ -0,0 +1,438 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.openai.common_utils import BaseOpenAILLM
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
azure_ad_cache = DualCache()
|
||||
|
||||
|
||||
class AzureOpenAIError(BaseLLMException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code,
|
||||
message,
|
||||
request: Optional[httpx.Request] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
headers: Optional[Union[httpx.Headers, dict]] = None,
|
||||
body: Optional[dict] = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
message=message,
|
||||
request=request,
|
||||
response=response,
|
||||
headers=headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
|
||||
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
|
||||
openai_headers = {}
|
||||
if "x-ratelimit-limit-requests" in headers:
|
||||
openai_headers["x-ratelimit-limit-requests"] = headers[
|
||||
"x-ratelimit-limit-requests"
|
||||
]
|
||||
if "x-ratelimit-remaining-requests" in headers:
|
||||
openai_headers["x-ratelimit-remaining-requests"] = headers[
|
||||
"x-ratelimit-remaining-requests"
|
||||
]
|
||||
if "x-ratelimit-limit-tokens" in headers:
|
||||
openai_headers["x-ratelimit-limit-tokens"] = headers["x-ratelimit-limit-tokens"]
|
||||
if "x-ratelimit-remaining-tokens" in headers:
|
||||
openai_headers["x-ratelimit-remaining-tokens"] = headers[
|
||||
"x-ratelimit-remaining-tokens"
|
||||
]
|
||||
llm_response_headers = {
|
||||
"{}-{}".format("llm_provider", k): v for k, v in headers.items()
|
||||
}
|
||||
|
||||
return {**llm_response_headers, **openai_headers}
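A quick illustration of the mapping above with a hypothetical header dict: every raw header is re-emitted with an `llm_provider-` prefix, and the rate-limit headers are additionally surfaced under their OpenAI-style names.

```python
from litellm.llms.azure.common_utils import process_azure_headers

headers = {
    "x-ratelimit-remaining-requests": "99",
    "x-ms-region": "eastus",
}
out = process_azure_headers(headers)

assert out["x-ratelimit-remaining-requests"] == "99"
assert out["llm_provider-x-ratelimit-remaining-requests"] == "99"
assert out["llm_provider-x-ms-region"] == "eastus"
```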
|
||||
|
||||
|
||||
def get_azure_ad_token_from_entra_id(
|
||||
tenant_id: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
scope: str = "https://cognitiveservices.azure.com/.default",
|
||||
) -> Callable[[], str]:
|
||||
"""
|
||||
Get Azure AD token provider from `client_id`, `client_secret`, and `tenant_id`
|
||||
|
||||
Args:
|
||||
tenant_id: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
scope: str
|
||||
|
||||
Returns:
|
||||
callable that returns a bearer token.
|
||||
"""
|
||||
from azure.identity import ClientSecretCredential, get_bearer_token_provider
|
||||
|
||||
verbose_logger.debug("Getting Azure AD Token from Entra ID")
|
||||
|
||||
if tenant_id.startswith("os.environ/"):
|
||||
_tenant_id = get_secret_str(tenant_id)
|
||||
else:
|
||||
_tenant_id = tenant_id
|
||||
|
||||
if client_id.startswith("os.environ/"):
|
||||
_client_id = get_secret_str(client_id)
|
||||
else:
|
||||
_client_id = client_id
|
||||
|
||||
if client_secret.startswith("os.environ/"):
|
||||
_client_secret = get_secret_str(client_secret)
|
||||
else:
|
||||
_client_secret = client_secret
|
||||
|
||||
verbose_logger.debug(
|
||||
"tenant_id %s, client_id %s, client_secret %s",
|
||||
_tenant_id,
|
||||
_client_id,
|
||||
_client_secret,
|
||||
)
|
||||
if _tenant_id is None or _client_id is None or _client_secret is None:
|
||||
raise ValueError("tenant_id, client_id, and client_secret must be provided")
|
||||
credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
|
||||
|
||||
verbose_logger.debug("credential %s", credential)
|
||||
|
||||
token_provider = get_bearer_token_provider(credential, scope)
|
||||
|
||||
verbose_logger.debug("token_provider %s", token_provider)
|
||||
|
||||
return token_provider
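A usage sketch, assuming `azure-identity` is installed and a real service principal exists; values may be passed directly or via the `os.environ/` convention handled above.

```python
from litellm.llms.azure.common_utils import get_azure_ad_token_from_entra_id

# Each value may also be given literally instead of via "os.environ/<VAR>".
token_provider = get_azure_ad_token_from_entra_id(
    tenant_id="os.environ/AZURE_TENANT_ID",
    client_id="os.environ/AZURE_CLIENT_ID",
    client_secret="os.environ/AZURE_CLIENT_SECRET",
)
bearer_token = token_provider()  # fetches / refreshes a token on each call
```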
|
||||
|
||||
|
||||
def get_azure_ad_token_from_username_password(
|
||||
client_id: str,
|
||||
azure_username: str,
|
||||
azure_password: str,
|
||||
scope: str = "https://cognitiveservices.azure.com/.default",
|
||||
) -> Callable[[], str]:
|
||||
"""
|
||||
Get Azure AD token provider from `client_id`, `azure_username`, and `azure_password`
|
||||
|
||||
Args:
|
||||
client_id: str
|
||||
azure_username: str
|
||||
azure_password: str
|
||||
scope: str
|
||||
|
||||
Returns:
|
||||
callable that returns a bearer token.
|
||||
"""
|
||||
from azure.identity import UsernamePasswordCredential, get_bearer_token_provider
|
||||
|
||||
verbose_logger.debug(
|
||||
"client_id %s, azure_username %s, azure_password %s",
|
||||
client_id,
|
||||
azure_username,
|
||||
azure_password,
|
||||
)
|
||||
credential = UsernamePasswordCredential(
|
||||
client_id=client_id,
|
||||
username=azure_username,
|
||||
password=azure_password,
|
||||
)
|
||||
|
||||
verbose_logger.debug("credential %s", credential)
|
||||
|
||||
token_provider = get_bearer_token_provider(credential, scope)
|
||||
|
||||
verbose_logger.debug("token_provider %s", token_provider)
|
||||
|
||||
return token_provider
|
||||
|
||||
|
||||
def get_azure_ad_token_from_oidc(azure_ad_token: str):
|
||||
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
|
||||
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
|
||||
azure_authority_host = os.getenv(
|
||||
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
|
||||
)
|
||||
|
||||
if azure_client_id is None or azure_tenant_id is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422,
|
||||
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
|
||||
)
|
||||
|
||||
oidc_token = get_secret_str(azure_ad_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=401,
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
)
|
||||
|
||||
azure_ad_token_cache_key = json.dumps(
|
||||
{
|
||||
"azure_client_id": azure_client_id,
|
||||
"azure_tenant_id": azure_tenant_id,
|
||||
"azure_authority_host": azure_authority_host,
|
||||
"oidc_token": oidc_token,
|
||||
}
|
||||
)
|
||||
|
||||
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
|
||||
if azure_ad_token_access_token is not None:
|
||||
return azure_ad_token_access_token
|
||||
|
||||
client = litellm.module_level_client
|
||||
req_token = client.post(
|
||||
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
|
||||
data={
|
||||
"client_id": azure_client_id,
|
||||
"grant_type": "client_credentials",
|
||||
"scope": "https://cognitiveservices.azure.com/.default",
|
||||
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
|
||||
"client_assertion": oidc_token,
|
||||
},
|
||||
)
|
||||
|
||||
if req_token.status_code != 200:
|
||||
raise AzureOpenAIError(
|
||||
status_code=req_token.status_code,
|
||||
message=req_token.text,
|
||||
)
|
||||
|
||||
azure_ad_token_json = req_token.json()
|
||||
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
|
||||
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
|
||||
|
||||
if azure_ad_token_access_token is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token access_token not returned"
|
||||
)
|
||||
|
||||
if azure_ad_token_expires_in is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Azure AD Token expires_in not returned"
|
||||
)
|
||||
|
||||
azure_ad_cache.set_cache(
|
||||
key=azure_ad_token_cache_key,
|
||||
value=azure_ad_token_access_token,
|
||||
ttl=azure_ad_token_expires_in,
|
||||
)
|
||||
|
||||
return azure_ad_token_access_token
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
azure_endpoint = azure_client_params.get("azure_endpoint", None)
|
||||
if azure_endpoint is not None:
|
||||
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
|
||||
if "/openai/deployments" in azure_endpoint:
|
||||
# this is base_url, not an azure_endpoint
|
||||
azure_client_params["base_url"] = azure_endpoint
|
||||
azure_client_params.pop("azure_endpoint")
|
||||
|
||||
return azure_client_params
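Two illustrative inputs for the helper above: a plain resource endpoint stays as `azure_endpoint`, while an endpoint that already contains `/openai/deployments` is treated as a full `base_url`.

```python
from litellm.llms.azure.common_utils import select_azure_base_url_or_endpoint

plain = select_azure_base_url_or_endpoint(
    {"azure_endpoint": "https://my-resource.openai.azure.com", "api_version": "2024-06-01"}
)
assert "azure_endpoint" in plain and "base_url" not in plain

full = select_azure_base_url_or_endpoint(
    {"azure_endpoint": "https://my-resource.openai.azure.com/openai/deployments/gpt-4o"}
)
assert full["base_url"].endswith("/openai/deployments/gpt-4o")
assert "azure_endpoint" not in full
```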
|
||||
|
||||
|
||||
class BaseAzureLLM(BaseOpenAILLM):
|
||||
def get_azure_openai_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
_is_async: bool = False,
|
||||
model: Optional[str] = None,
|
||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||
client_initialization_params: dict = locals()
|
||||
if client is None:
|
||||
cached_client = self.get_cached_openai_client(
|
||||
client_initialization_params=client_initialization_params,
|
||||
client_type="azure",
|
||||
)
|
||||
if cached_client:
|
||||
if isinstance(cached_client, AzureOpenAI) or isinstance(
|
||||
cached_client, AsyncAzureOpenAI
|
||||
):
|
||||
return cached_client
|
||||
|
||||
azure_client_params = self.initialize_azure_sdk_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model_name=model,
|
||||
api_version=api_version,
|
||||
is_async=_is_async,
|
||||
)
|
||||
if _is_async is True:
|
||||
openai_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
openai_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
if api_version is not None and isinstance(
|
||||
openai_client._custom_query, dict
|
||||
):
|
||||
# apply the user-passed api_version if the client doesn't already set one
|
||||
openai_client._custom_query.setdefault("api-version", api_version)
|
||||
|
||||
# save client in-memory cache
|
||||
self.set_cached_openai_client(
|
||||
openai_client=openai_client,
|
||||
client_initialization_params=client_initialization_params,
|
||||
client_type="azure",
|
||||
)
|
||||
return openai_client
|
||||
|
||||
def initialize_azure_sdk_client(
|
||||
self,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
model_name: Optional[str],
|
||||
api_version: Optional[str],
|
||||
is_async: bool,
|
||||
) -> dict:
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
# An explicit api_key takes priority over Azure AD-based auth
|
||||
azure_ad_token = litellm_params.get("azure_ad_token")
|
||||
tenant_id = litellm_params.get("tenant_id", os.getenv("AZURE_TENANT_ID"))
|
||||
client_id = litellm_params.get("client_id", os.getenv("AZURE_CLIENT_ID"))
|
||||
client_secret = litellm_params.get(
|
||||
"client_secret", os.getenv("AZURE_CLIENT_SECRET")
|
||||
)
|
||||
azure_username = litellm_params.get(
|
||||
"azure_username", os.getenv("AZURE_USERNAME")
|
||||
)
|
||||
azure_password = litellm_params.get(
|
||||
"azure_password", os.getenv("AZURE_PASSWORD")
|
||||
)
|
||||
max_retries = litellm_params.get("max_retries")
|
||||
timeout = litellm_params.get("timeout")
|
||||
if not api_key and tenant_id and client_id and client_secret:
|
||||
verbose_logger.debug(
|
||||
"Using Azure AD Token Provider from Entra ID for Azure Auth"
|
||||
)
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entra_id(
|
||||
tenant_id=tenant_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
if azure_username and azure_password and client_id:
|
||||
verbose_logger.debug("Using Azure Username and Password for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_username_password(
|
||||
azure_username=azure_username,
|
||||
azure_password=azure_password,
|
||||
client_id=client_id,
|
||||
)
|
||||
|
||||
if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
|
||||
verbose_logger.debug("Using Azure OIDC Token for Azure Auth")
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
elif (
|
||||
not api_key
|
||||
and azure_ad_token_provider is None
|
||||
and litellm.enable_azure_ad_token_refresh is True
|
||||
):
|
||||
verbose_logger.debug(
|
||||
"Using Azure AD token provider based on Service Principal with Secret workflow for Azure Auth"
|
||||
)
|
||||
try:
|
||||
azure_ad_token_provider = get_azure_ad_token_provider()
|
||||
except ValueError:
|
||||
verbose_logger.debug("Azure AD Token Provider could not be used.")
|
||||
if api_version is None:
|
||||
api_version = os.getenv(
|
||||
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
|
||||
)
|
||||
|
||||
_api_key = api_key
|
||||
if _api_key is not None and isinstance(_api_key, str):
|
||||
# only show the first 8 chars of the api_key in logs
|
||||
_api_key = _api_key[:8] + "*" * 15
|
||||
verbose_logger.debug(
|
||||
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
|
||||
)
|
||||
azure_client_params = {
|
||||
"api_key": api_key,
|
||||
"azure_endpoint": api_base,
|
||||
"api_version": api_version,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
"azure_ad_token_provider": azure_ad_token_provider,
|
||||
}
|
||||
# init http client + SSL Verification settings
|
||||
if is_async is True:
|
||||
azure_client_params["http_client"] = self._get_async_http_client()
|
||||
else:
|
||||
azure_client_params["http_client"] = self._get_sync_http_client()
|
||||
|
||||
if max_retries is not None:
|
||||
azure_client_params["max_retries"] = max_retries
|
||||
if timeout is not None:
|
||||
azure_client_params["timeout"] = timeout
|
||||
|
||||
if azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
|
||||
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
|
||||
|
||||
azure_client_params = select_azure_base_url_or_endpoint(
|
||||
azure_client_params=azure_client_params
|
||||
)
|
||||
|
||||
return azure_client_params
|
||||
|
||||
def _init_azure_client_for_cloudflare_ai_gateway(
|
||||
self,
|
||||
api_base: str,
|
||||
model: str,
|
||||
api_version: str,
|
||||
max_retries: int,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
api_key: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
azure_ad_token_provider: Optional[Callable[[], str]],
|
||||
acompletion: bool,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
|
||||
## build base url - assume api base includes resource name
|
||||
if client is None:
|
||||
if not api_base.endswith("/"):
|
||||
api_base += "/"
|
||||
api_base += f"{model}"
|
||||
|
||||
azure_client_params: Dict[str, Any] = {
|
||||
"api_version": api_version,
|
||||
"base_url": f"{api_base}",
|
||||
"http_client": litellm.client_session,
|
||||
"max_retries": max_retries,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if api_key is not None:
|
||||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
if azure_ad_token.startswith("oidc/"):
|
||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
|
||||
if acompletion is True:
|
||||
client = AsyncAzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
return client
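The gateway URL handling above simply appends the deployment name to the supplied base. A small standalone sketch with hypothetical Cloudflare account, gateway, and resource names:

```python
# Mirrors the base-url construction above (illustrative values only).
api_base = "https://gateway.ai.cloudflare.com/v1/my-account/my-gateway/azure-openai/my-resource"
model = "gpt-4o-deployment"

if not api_base.endswith("/"):
    api_base += "/"
api_base += f"{model}"

assert api_base == (
    "https://gateway.ai.cloudflare.com/v1/my-account/my-gateway/"
    "azure-openai/my-resource/gpt-4o-deployment"
)
```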
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,378 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse
|
||||
|
||||
from ...openai.completion.transformation import OpenAITextCompletionConfig
|
||||
from ..common_utils import AzureOpenAIError, BaseAzureLLM
|
||||
|
||||
openai_text_completion_config = OpenAITextCompletionConfig()
|
||||
|
||||
|
||||
class AzureTextCompletion(BaseAzureLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def validate_environment(self, api_key, azure_ad_token):
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
}
|
||||
if api_key is not None:
|
||||
headers["api-key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
return headers
|
||||
|
||||
def completion( # noqa: PLR0915
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
model_response: ModelResponse,
|
||||
api_key: str,
|
||||
api_base: str,
|
||||
api_version: str,
|
||||
api_type: str,
|
||||
azure_ad_token: str,
|
||||
azure_ad_token_provider: Optional[Callable],
|
||||
print_verbose: Callable,
|
||||
timeout,
|
||||
logging_obj,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
logger_fn,
|
||||
acompletion: bool = False,
|
||||
headers: Optional[dict] = None,
|
||||
client=None,
|
||||
):
|
||||
try:
|
||||
if model is None or messages is None:
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="Missing model or messages"
|
||||
)
|
||||
|
||||
max_retries = optional_params.pop("max_retries", 2)
|
||||
prompt = prompt_factory(
|
||||
messages=messages, model=model, custom_llm_provider="azure_text"
|
||||
)
|
||||
|
||||
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
||||
### if so - set the model as part of the base url
|
||||
if "gateway.ai.cloudflare.com" in api_base:
|
||||
## build base url - assume api base includes resource name
|
||||
client = self._init_azure_client_for_cloudflare_ai_gateway(
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
client=client,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
|
||||
data = {"model": None, "prompt": prompt, **optional_params}
|
||||
else:
|
||||
data = {
|
||||
"model": model, # type: ignore
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
data=data,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
return self.acompletion(
|
||||
api_base=api_base,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
model=model,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
max_retries=max_retries,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] is True:
|
||||
return self.streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
data=data,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"headers": {
|
||||
"api_key": api_key,
|
||||
"azure_ad_token": azure_ad_token,
|
||||
},
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
_is_async=False,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
|
||||
raw_response = azure_client.completions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
response = raw_response.parse()
|
||||
stringified_response = response.model_dump()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=stringified_response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
return (
|
||||
openai_text_completion_config.convert_to_chat_model_response_object(
|
||||
response_object=TextCompletionResponse(**stringified_response),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
)
|
||||
except AzureOpenAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise AzureOpenAIError(
|
||||
status_code=status_code, message=str(e), headers=error_headers
|
||||
)
|
||||
|
||||
async def acompletion(
|
||||
self,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
model: str,
|
||||
api_base: str,
|
||||
data: dict,
|
||||
timeout: Any,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: Any,
|
||||
max_retries: int,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None, # this is the AsyncAzureOpenAI
|
||||
litellm_params: dict = {},
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
# init AzureOpenAI Client
|
||||
# setting Azure client
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AsyncAzureOpenAI",
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["prompt"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
raw_response = await azure_client.completions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
response = raw_response.parse()
|
||||
return openai_text_completion_config.convert_to_chat_model_response_object(
|
||||
response_object=response.model_dump(),
|
||||
model_response_object=model_response,
|
||||
)
|
||||
except AzureOpenAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise AzureOpenAIError(
|
||||
status_code=status_code, message=str(e), headers=error_headers
|
||||
)
|
||||
|
||||
def streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None,
|
||||
litellm_params: dict = {},
|
||||
):
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise AzureOpenAIError(
|
||||
status_code=422, message="max retries must be an int"
|
||||
)
|
||||
# init AzureOpenAI Client
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=False,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AzureOpenAI",
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["prompt"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
raw_response = azure_client.completions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
response = raw_response.parse()
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="azure_text",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
logging_obj,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
data: dict,
|
||||
model: str,
|
||||
timeout: Any,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client=None,
|
||||
litellm_params: dict = {},
|
||||
):
|
||||
try:
|
||||
# init AzureOpenAI Client
|
||||
azure_client = self.get_azure_openai_client(
|
||||
api_version=api_version,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
_is_async=True,
|
||||
client=client,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||
raise AzureOpenAIError(
|
||||
status_code=500,
|
||||
message="azure_client is not an instance of AsyncAzureOpenAI",
|
||||
)
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data["prompt"],
|
||||
api_key=azure_client.api_key,
|
||||
additional_args={
|
||||
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
|
||||
"api_base": azure_client._base_url._uri_reference,
|
||||
"acompletion": True,
|
||||
"complete_input_dict": data,
|
||||
},
|
||||
)
|
||||
raw_response = await azure_client.completions.with_raw_response.create(
|
||||
**data, timeout=timeout
|
||||
)
|
||||
response = raw_response.parse()
|
||||
# return response
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="azure_text",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", 500)
|
||||
error_headers = getattr(e, "headers", None)
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_headers is None and error_response:
|
||||
error_headers = getattr(error_response, "headers", None)
|
||||
raise AzureOpenAIError(
|
||||
status_code=status_code, message=str(e), headers=error_headers
|
||||
)
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...openai.completion.transformation import OpenAITextCompletionConfig
|
||||
|
||||
|
||||
class AzureOpenAITextConfig(OpenAITextCompletionConfig):
|
||||
"""
|
||||
Reference: https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
The class `AzureOpenAITextConfig` provides configuration for Azure OpenAI's text completion API. It inherits from `OpenAITextCompletionConfig`. Below are the parameters:
|
||||
|
||||
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||
|
||||
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||
|
||||
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||
|
||||
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||
|
||||
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||
|
||||
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Helper util for handling azure openai-specific cost calculation
|
||||
- e.g.: prompt caching
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
def cost_per_token(
|
||||
model: str, usage: Usage, response_time_ms: Optional[float] = 0.0
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
|
||||
|
||||
Input:
|
||||
- model: str, the model name without provider prefix
|
||||
- usage: LiteLLM Usage block, containing prompt caching information
|
||||
|
||||
Returns:
|
||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||
"""
|
||||
## GET MODEL INFO
|
||||
model_info = get_model_info(model=model, custom_llm_provider="azure")
|
||||
cached_tokens: Optional[int] = None
|
||||
## CALCULATE INPUT COST
|
||||
non_cached_text_tokens = usage.prompt_tokens
|
||||
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
|
||||
cached_tokens = usage.prompt_tokens_details.cached_tokens
|
||||
non_cached_text_tokens = non_cached_text_tokens - cached_tokens
|
||||
prompt_cost: float = non_cached_text_tokens * model_info["input_cost_per_token"]
|
||||
|
||||
## CALCULATE OUTPUT COST
|
||||
completion_cost: float = (
|
||||
usage["completion_tokens"] * model_info["output_cost_per_token"]
|
||||
)
|
||||
|
||||
## Prompt Caching cost calculation
|
||||
if model_info.get("cache_read_input_token_cost") is not None and cached_tokens:
|
||||
# Note: cached_tokens was read above from usage.prompt_tokens_details.cached_tokens
|
||||
prompt_cost += cached_tokens * (
|
||||
model_info.get("cache_read_input_token_cost", 0) or 0
|
||||
)
|
||||
|
||||
## Speech / Audio cost calculation
|
||||
if (
|
||||
"output_cost_per_second" in model_info
|
||||
and model_info["output_cost_per_second"] is not None
|
||||
and response_time_ms is not None
|
||||
):
|
||||
verbose_logger.debug(
|
||||
f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; response time: {response_time_ms}"
|
||||
)
|
||||
## COST PER SECOND ##
|
||||
prompt_cost = 0
|
||||
completion_cost = model_info["output_cost_per_second"] * response_time_ms / 1000
|
||||
|
||||
return prompt_cost, completion_cost
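A worked example of the arithmetic above with hypothetical per-token prices (real prices come from litellm's model cost map): 1,000 prompt tokens, 400 of them cache hits, plus 200 completion tokens.

```python
# Hypothetical prices, for illustration only.
input_cost_per_token = 2.5e-06
cache_read_input_token_cost = 1.25e-06
output_cost_per_token = 1.0e-05

prompt_tokens, cached_tokens, completion_tokens = 1000, 400, 200

non_cached = prompt_tokens - cached_tokens                    # 600
prompt_cost = non_cached * input_cost_per_token               # 0.0015
prompt_cost += cached_tokens * cache_read_input_token_cost    # + 0.0005 -> 0.002
completion_cost = completion_tokens * output_cost_per_token   # 0.002

assert round(prompt_cost, 6) == 0.002
assert round(completion_cost, 6) == 0.002
```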
|
||||
Binary file not shown.
@@ -0,0 +1,283 @@
|
||||
from typing import Any, Coroutine, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import *
|
||||
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureOpenAIFilesAPI(BaseAzureLLM):
|
||||
"""
|
||||
AzureOpenAI file methods, used to support batches:
|
||||
- create_file()
|
||||
- retrieve_file()
|
||||
- list_files()
|
||||
- delete_file()
|
||||
- file_content()
|
||||
- update_file()
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def acreate_file(
|
||||
self,
|
||||
create_file_data: CreateFileRequest,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> OpenAIFileObject:
|
||||
verbose_logger.debug("create_file_data=%s", create_file_data)
|
||||
response = await openai_client.files.create(**create_file_data)
|
||||
verbose_logger.debug("create_file_response=%s", response)
|
||||
return OpenAIFileObject(**response.model_dump())
|
||||
|
||||
def create_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_file_data: CreateFileRequest,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.acreate_file(
|
||||
create_file_data=create_file_data, openai_client=openai_client
|
||||
)
|
||||
response = cast(AzureOpenAI, openai_client).files.create(**create_file_data)
|
||||
return OpenAIFileObject(**response.model_dump())
|
||||
|
||||
async def afile_content(
|
||||
self,
|
||||
file_content_request: FileContentRequest,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
response = await openai_client.files.content(**file_content_request)
|
||||
return HttpxBinaryResponseContent(response=response.response)
|
||||
|
||||
def file_content(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_content_request: FileContentRequest,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Union[
|
||||
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
|
||||
]:
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.afile_content( # type: ignore
|
||||
file_content_request=file_content_request,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = cast(AzureOpenAI, openai_client).files.content(
|
||||
**file_content_request
|
||||
)
|
||||
|
||||
return HttpxBinaryResponseContent(response=response.response)
|
||||
|
||||
async def aretrieve_file(
|
||||
self,
|
||||
file_id: str,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FileObject:
|
||||
response = await openai_client.files.retrieve(file_id=file_id)
|
||||
return response
|
||||
|
||||
def retrieve_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.aretrieve_file( # type: ignore
|
||||
file_id=file_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = openai_client.files.retrieve(file_id=file_id)
|
||||
|
||||
return response
|
||||
|
||||
async def adelete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
) -> FileDeleted:
|
||||
response = await openai_client.files.delete(file_id=file_id)
|
||||
|
||||
if not isinstance(response, FileDeleted): # azure returns an empty string
|
||||
return FileDeleted(id=file_id, deleted=True, object="file")
|
||||
return response
|
||||
|
||||
def delete_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.adelete_file( # type: ignore
|
||||
file_id=file_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = openai_client.files.delete(file_id=file_id)
|
||||
|
||||
if not isinstance(response, FileDeleted): # azure returns an empty string
|
||||
return FileDeleted(id=file_id, deleted=True, object="file")
|
||||
|
||||
return response
|
||||
|
||||
async def alist_files(
|
||||
self,
|
||||
openai_client: AsyncAzureOpenAI,
|
||||
purpose: Optional[str] = None,
|
||||
):
|
||||
if isinstance(purpose, str):
|
||||
response = await openai_client.files.list(purpose=purpose)
|
||||
else:
|
||||
response = await openai_client.files.list()
|
||||
return response
|
||||
|
||||
def list_files(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
purpose: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncAzureOpenAI):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.alist_files( # type: ignore
|
||||
purpose=purpose,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
|
||||
if isinstance(purpose, str):
|
||||
response = openai_client.files.list(purpose=purpose)
|
||||
else:
|
||||
response = openai_client.files.list()
|
||||
|
||||
return response
|
||||
Binary file not shown.
@@ -0,0 +1,40 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
|
||||
|
||||
|
||||
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
|
||||
"""
|
||||
AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
|
||||
"""
|
||||
|
||||
def get_openai_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
_is_async: bool = False,
|
||||
api_version: Optional[str] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
# Override to use Azure-specific client initialization
|
||||
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
|
||||
client = None
|
||||
|
||||
return self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
This file contains the logic for calling Azure OpenAI's `/openai/realtime` endpoint.
|
||||
|
||||
This requires websockets, and is currently only supported on LiteLLM Proxy.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
|
||||
from ..azure import AzureChatCompletion
|
||||
|
||||
# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
|
||||
|
||||
|
||||
async def forward_messages(client_ws: Any, backend_ws: Any):
|
||||
import websockets
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await backend_ws.recv()
|
||||
await client_ws.send_text(message)
|
||||
except websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class AzureOpenAIRealtime(AzureChatCompletion):
|
||||
def _construct_url(self, api_base: str, model: str, api_version: str) -> str:
|
||||
"""
|
||||
Example output:
|
||||
"wss://my-endpoint-sweden-berri992.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview";
|
||||
|
||||
"""
|
||||
api_base = api_base.replace("https://", "wss://")
|
||||
return (
|
||||
f"{api_base}/openai/realtime?api-version={api_version}&deployment={model}"
|
||||
)
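A quick check of the URL construction above, with a hypothetical endpoint and deployment name (the module path is assumed to match this commit's folder layout):

```python
from litellm.llms.azure.realtime.handler import AzureOpenAIRealtime

url = AzureOpenAIRealtime()._construct_url(
    api_base="https://my-endpoint.openai.azure.com",
    model="gpt-4o-realtime-preview",
    api_version="2024-10-01-preview",
)
assert url == (
    "wss://my-endpoint.openai.azure.com/openai/realtime"
    "?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"
)
```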
|
||||
|
||||
async def async_realtime(
|
||||
self,
|
||||
model: str,
|
||||
websocket: Any,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client: Optional[Any] = None,
|
||||
logging_obj: Optional[LiteLLMLogging] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
import websockets
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for Azure OpenAI calls")
|
||||
if api_version is None:
|
||||
raise ValueError("api_version is required for Azure OpenAI calls")
|
||||
|
||||
url = self._construct_url(api_base, model, api_version)
|
||||
|
||||
try:
|
||||
async with websockets.connect( # type: ignore
|
||||
url,
|
||||
extra_headers={
|
||||
"api-key": api_key, # type: ignore
|
||||
},
|
||||
) as backend_ws:
|
||||
realtime_streaming = RealTimeStreaming(
|
||||
websocket, backend_ws, logging_obj
|
||||
)
|
||||
await realtime_streaming.bidirectional_forward()
|
||||
|
||||
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
|
||||
await websocket.close(code=e.status_code, reason=str(e))
|
||||
except Exception:
|
||||
pass
|
||||
Binary file not shown.
@@ -0,0 +1,138 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import *
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
)
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Constructs a complete URL for the API request.
|
||||
|
||||
Args:
|
||||
- api_base: Base URL, e.g.,
|
||||
"https://litellm8397336933.openai.azure.com"
|
||||
OR
|
||||
"https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
|
||||
- litellm_params: LiteLLM params dict; "api_version" is read from here when the base URL does not already carry an "api-version" query parameter.
|
||||
|
||||
Returns:
|
||||
- A complete URL string, e.g.,
|
||||
"https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
|
||||
"""
|
||||
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
|
||||
)
|
||||
original_url = httpx.URL(api_base)
|
||||
|
||||
# Extract api_version or use default
|
||||
api_version = cast(Optional[str], litellm_params.get("api_version"))
|
||||
|
||||
# Create a new dictionary with existing params
|
||||
query_params = dict(original_url.params)
|
||||
|
||||
# Add api_version if needed
|
||||
if "api-version" not in query_params and api_version:
|
||||
query_params["api-version"] = api_version
|
||||
|
||||
# Add the path to the base URL
|
||||
if "/openai/responses" not in api_base:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/openai/responses"
|
||||
)
|
||||
else:
|
||||
new_url = api_base
|
||||
|
||||
# Use the new query_params dictionary
|
||||
final_url = httpx.URL(new_url).copy_with(params=query_params)
|
||||
|
||||
return str(final_url)
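A minimal sketch of the URL assembly above, using a hypothetical resource name and a plain string append in place of `_add_path_to_api_base`:

import httpx

api_base = "https://litellm8397336933.openai.azure.com"  # hypothetical resource
api_version = "2024-05-01-preview"

query_params = dict(httpx.URL(api_base).params)
if "api-version" not in query_params and api_version:
    query_params["api-version"] = api_version
if "/openai/responses" not in api_base:
    api_base = api_base.rstrip("/") + "/openai/responses"  # simplified path append

print(str(httpx.URL(api_base).copy_with(params=query_params)))
# https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview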
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
#########################################################
|
||||
def transform_delete_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the delete response API request into a URL and data
|
||||
|
||||
Azure OpenAI API expects the following request:
|
||||
- DELETE /openai/responses/{response_id}?api-version=xxx
|
||||
|
||||
This function handles URLs with query parameters by inserting the response_id
|
||||
at the correct location (before any query parameters).
|
||||
"""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
# Parse the URL to separate its components
|
||||
parsed_url = urlparse(api_base)
|
||||
|
||||
# Insert the response_id at the end of the path component
|
||||
# Remove trailing slash if present to avoid double slashes
|
||||
path = parsed_url.path.rstrip("/")
|
||||
new_path = f"{path}/{response_id}"
|
||||
|
||||
# Reconstruct the URL with all original components but with the modified path
|
||||
delete_url = urlunparse(
|
||||
(
|
||||
parsed_url.scheme, # http, https
|
||||
parsed_url.netloc, # domain name, port
|
||||
new_path, # path with response_id added
|
||||
parsed_url.params, # parameters
|
||||
parsed_url.query, # query string
|
||||
parsed_url.fragment, # fragment
|
||||
)
|
||||
)
|
||||
|
||||
data: Dict = {}
|
||||
verbose_logger.debug(f"delete response url={delete_url}")
|
||||
return delete_url, data
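A short sketch of the URL rewrite above, with a hypothetical response_id, showing that the id lands in the path while the api-version query string is preserved:

from urllib.parse import urlparse, urlunparse

api_base = "https://example-resource.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
response_id = "resp_123"  # hypothetical id

p = urlparse(api_base)
delete_url = urlunparse(
    (p.scheme, p.netloc, p.path.rstrip("/") + f"/{response_id}", p.params, p.query, p.fragment)
)
print(delete_url)
# https://example-resource.openai.azure.com/openai/responses/resp_123?api-version=2024-05-01-preview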
|
||||
@@ -0,0 +1 @@
|
||||
`/chat/completion` calls routed via `openai.py`.
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
LLM Calling done in `openai/openai.py`
|
||||
"""
|
||||
@@ -0,0 +1,321 @@
|
||||
import enum
|
||||
from typing import Any, List, Optional, Tuple, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
_audio_or_image_in_message_content,
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
||||
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
|
||||
from litellm.llms.openai.openai import OpenAIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, ProviderField
|
||||
from litellm.utils import _add_path_to_api_base, supports_tool_choice
|
||||
|
||||
|
||||
class AzureFoundryErrorStrings(str, enum.Enum):
|
||||
SET_EXTRA_PARAMETERS_TO_PASS_THROUGH = "Set extra-parameters to 'pass-through'"
|
||||
|
||||
|
||||
class AzureAIStudioConfig(OpenAIConfig):
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
model_supports_tool_choice = True # azure ai supports this by default
|
||||
if not supports_tool_choice(model=f"azure_ai/{model}"):
|
||||
model_supports_tool_choice = False
|
||||
supported_params = super().get_supported_openai_params(model)
|
||||
if not model_supports_tool_choice:
|
||||
filtered_supported_params = []
|
||||
for param in supported_params:
|
||||
if param != "tool_choice":
|
||||
filtered_supported_params.append(param)
|
||||
return filtered_supported_params
|
||||
return supported_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_base and self._should_use_api_key_header(api_base):
|
||||
headers["api-key"] = api_key
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
def _should_use_api_key_header(self, api_base: str) -> bool:
|
||||
"""
|
||||
Returns True if the request should use `api-key` header for authentication.
|
||||
"""
|
||||
parsed_url = urlparse(api_base)
|
||||
host = parsed_url.hostname
|
||||
if host and (
|
||||
host.endswith(".services.ai.azure.com")
|
||||
or host.endswith(".openai.azure.com")
|
||||
):
|
||||
return True
|
||||
return False
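A small sketch of the host check above (endpoints are hypothetical); Azure-managed hosts get the `api-key` header, everything else falls back to a bearer token:

from urllib.parse import urlparse

for base in (
    "https://my-project.services.ai.azure.com",
    "https://my-project.openai.azure.com",
    "https://my-own-gateway.example.com",
):
    host = urlparse(base).hostname or ""
    use_api_key = host.endswith(".services.ai.azure.com") or host.endswith(".openai.azure.com")
    print(base, "->", "api-key header" if use_api_key else "Authorization: Bearer header")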
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Constructs a complete URL for the API request.
|
||||
|
||||
Args:
|
||||
- api_base: Base URL, e.g.,
|
||||
"https://litellm8397336933.services.ai.azure.com"
|
||||
OR
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
|
||||
- model: Model name.
|
||||
- optional_params: Additional query parameters, including "api_version".
|
||||
- stream: If streaming is required (optional).
|
||||
|
||||
Returns:
|
||||
- A complete URL string, e.g.,
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
|
||||
)
|
||||
original_url = httpx.URL(api_base)
|
||||
|
||||
# Extract api_version or use default
|
||||
api_version = cast(Optional[str], litellm_params.get("api_version"))
|
||||
|
||||
# Create a new dictionary with existing params
|
||||
query_params = dict(original_url.params)
|
||||
|
||||
# Add api_version if needed
|
||||
if "api-version" not in query_params and api_version:
|
||||
query_params["api-version"] = api_version
|
||||
|
||||
# Add the path to the base URL
|
||||
if "services.ai.azure.com" in api_base:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/models/chat/completions"
|
||||
)
|
||||
else:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/chat/completions"
|
||||
)
|
||||
|
||||
# Use the new query_params dictionary
|
||||
final_url = httpx.URL(new_url).copy_with(params=query_params)
|
||||
|
||||
return str(final_url)
|
||||
|
||||
def get_required_params(self) -> List[ProviderField]:
|
||||
"""For a given provider, return it's required fields with a description"""
|
||||
return [
|
||||
ProviderField(
|
||||
field_name="api_key",
|
||||
field_type="string",
|
||||
field_description="Your Azure AI Studio API Key.",
|
||||
field_value="zEJ...",
|
||||
),
|
||||
ProviderField(
|
||||
field_name="api_base",
|
||||
field_type="string",
|
||||
field_description="Your Azure AI Studio API Base.",
|
||||
field_value="https://Mistral-serverless.",
|
||||
),
|
||||
]
|
||||
|
||||
def _transform_messages(
|
||||
self,
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
) -> List:
|
||||
"""
|
||||
- Azure AI Studio doesn't support content as a list. This handles:
|
||||
1. Transforms list content to a string.
|
||||
2. If message contains an image or audio, send as is (user-intended)
|
||||
"""
|
||||
for message in messages:
|
||||
# Do nothing if the message contains an image or audio
|
||||
if _audio_or_image_in_message_content(message):
|
||||
continue
|
||||
|
||||
texts = convert_content_list_to_str(message=message)
|
||||
if texts:
|
||||
message["content"] = texts
|
||||
return messages
|
||||
|
||||
def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
|
||||
try:
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
if (
|
||||
model in litellm.open_ai_chat_completion_models
|
||||
or model in litellm.open_ai_text_completion_models
|
||||
or model in litellm.open_ai_embedding_models
|
||||
):
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
custom_llm_provider: str,
|
||||
) -> Tuple[Optional[str], Optional[str], str]:
|
||||
api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
|
||||
dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
|
||||
|
||||
if self._is_azure_openai_model(model=model, api_base=api_base):
|
||||
verbose_logger.debug(
|
||||
"Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
|
||||
model
|
||||
)
|
||||
)
|
||||
custom_llm_provider = "azure"
|
||||
return api_base, dynamic_api_key, custom_llm_provider
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
extra_body = optional_params.pop("extra_body", {})
|
||||
if extra_body and isinstance(extra_body, dict):
|
||||
optional_params.update(extra_body)
|
||||
optional_params.pop("max_retries", None)
|
||||
return super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
model_response.model = f"azure_ai/{model}"
|
||||
return super().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
self, e: httpx.HTTPStatusError, litellm_params: dict
|
||||
) -> bool:
|
||||
should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
|
||||
error_text = e.response.text
|
||||
|
||||
if should_drop_params and "Extra inputs are not permitted" in error_text:
|
||||
return True
|
||||
elif (
|
||||
"unknown field: parameter index is not a valid field" in error_text
|
||||
): # remove index from tool calls
|
||||
return True
|
||||
elif (
|
||||
AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
|
||||
in error_text
|
||||
): # remove extra-parameters from tool calls
|
||||
return True
|
||||
return super().should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
e=e, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
@property
|
||||
def max_retry_on_unprocessable_entity_error(self) -> int:
|
||||
return 2
|
||||
|
||||
def transform_request_on_unprocessable_entity_error(
|
||||
self, e: httpx.HTTPStatusError, request_data: dict
|
||||
) -> dict:
|
||||
_messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
|
||||
if (
|
||||
"unknown field: parameter index is not a valid field" in e.response.text
|
||||
and _messages is not None
|
||||
):
|
||||
litellm.remove_index_from_tool_calls(
|
||||
messages=_messages,
|
||||
)
|
||||
elif (
|
||||
AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
|
||||
in e.response.text
|
||||
):
|
||||
request_data = self._drop_extra_params_from_request_data(
|
||||
request_data, e.response.text
|
||||
)
|
||||
data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
|
||||
return data
|
||||
|
||||
def _drop_extra_params_from_request_data(
|
||||
self, request_data: dict, error_text: str
|
||||
) -> dict:
|
||||
params_to_drop = self._extract_params_to_drop_from_error_text(error_text)
|
||||
if params_to_drop:
|
||||
for param in params_to_drop:
|
||||
if param in request_data:
|
||||
request_data.pop(param, None)
|
||||
return request_data
|
||||
|
||||
def _extract_params_to_drop_from_error_text(
|
||||
self, error_text: str
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Error text looks like this:
|
||||
"Extra parameters ['stream_options', 'extra-parameters'] are not allowed when extra-parameters is not set or set to be 'error'.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Extract parameters within square brackets
|
||||
match = re.search(r"\[(.*?)\]", error_text)
|
||||
if not match:
|
||||
return []
|
||||
|
||||
# Parse the extracted string into a list of parameter names
|
||||
params_str = match.group(1)
|
||||
params = []
|
||||
for param in params_str.split(","):
|
||||
# Clean up the parameter name (remove quotes, spaces)
|
||||
clean_param = param.strip().strip("'").strip('"')
|
||||
if clean_param:
|
||||
params.append(clean_param)
|
||||
return params
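Run against the example error text from the docstring, the extraction above behaves like this (standalone sketch):

import re

error_text = (
    "Extra parameters ['stream_options', 'extra-parameters'] are not allowed when "
    "extra-parameters is not set or set to be 'error'."
)
match = re.search(r"\[(.*?)\]", error_text)
params = (
    [p.strip().strip("'").strip('"') for p in match.group(1).split(",")] if match else []
)
print(params)  # ['stream_options', 'extra-parameters']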
|
||||
@@ -0,0 +1 @@
|
||||
from .handler import AzureAIEmbedding
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Covers:
|
||||
- Cohere request format
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from litellm.types.llms.azure_ai import ImageEmbeddingInput, ImageEmbeddingRequest
|
||||
from litellm.types.llms.openai import EmbeddingCreateParams
|
||||
from litellm.types.utils import EmbeddingResponse, Usage
|
||||
from litellm.utils import is_base64_encoded
|
||||
|
||||
|
||||
class AzureAICohereConfig:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _map_azure_model_group(self, model: str) -> str:
|
||||
if model == "offer-cohere-embed-multili-paygo":
|
||||
return "Cohere-embed-v3-multilingual"
|
||||
elif model == "offer-cohere-embed-english-paygo":
|
||||
return "Cohere-embed-v3-english"
|
||||
|
||||
return model
|
||||
|
||||
def _transform_request_image_embeddings(
|
||||
self, input: List[str], optional_params: dict
|
||||
) -> ImageEmbeddingRequest:
|
||||
"""
|
||||
Assumes every str in the list is a base64-encoded string.
|
||||
"""
|
||||
image_input: List[ImageEmbeddingInput] = []
|
||||
for i in input:
|
||||
embedding_input = ImageEmbeddingInput(image=i)
|
||||
image_input.append(embedding_input)
|
||||
return ImageEmbeddingRequest(input=image_input, **optional_params)
|
||||
|
||||
def _transform_request(
|
||||
self, input: List[str], optional_params: dict, model: str
|
||||
) -> Tuple[ImageEmbeddingRequest, EmbeddingCreateParams, List[int]]:
|
||||
"""
|
||||
Returns the inputs for `/image/embeddings`, the inputs for `/v1/embeddings`, and the list of image_embedding_idx for recombination.
|
||||
"""
|
||||
image_embeddings: List[str] = []
|
||||
image_embedding_idx: List[int] = []
|
||||
for idx, i in enumerate(input):
|
||||
"""
|
||||
- is base64 -> route to image embeddings
|
||||
- is ImageEmbeddingInput -> route to image embeddings
|
||||
- else -> route to `/v1/embeddings`
|
||||
"""
|
||||
if is_base64_encoded(i):
|
||||
image_embeddings.append(i)
|
||||
image_embedding_idx.append(idx)
|
||||
|
||||
## REMOVE IMAGE EMBEDDINGS FROM input list
|
||||
filtered_input = [
|
||||
item for idx, item in enumerate(input) if idx not in image_embedding_idx
|
||||
]
|
||||
|
||||
v1_embeddings_request = EmbeddingCreateParams(
|
||||
input=filtered_input, model=model, **optional_params
|
||||
)
|
||||
image_embeddings_request = self._transform_request_image_embeddings(
|
||||
input=image_embeddings, optional_params=optional_params
|
||||
)
|
||||
|
||||
return image_embeddings_request, v1_embeddings_request, image_embedding_idx
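A minimal sketch of the split performed above; inputs are hypothetical, and a naive `data:image/` prefix check stands in for `litellm.utils.is_base64_encoded`:

inputs = ["hello world", "data:image/png;base64,iVBORw0KGgo=", "another sentence"]

image_embedding_idx = [i for i, v in enumerate(inputs) if v.startswith("data:image/")]
image_inputs = [inputs[i] for i in image_embedding_idx]
text_inputs = [v for i, v in enumerate(inputs) if i not in image_embedding_idx]

print(image_embedding_idx)  # [1]
print(image_inputs)         # ['data:image/png;base64,iVBORw0KGgo=']
print(text_inputs)          # ['hello world', 'another sentence']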
|
||||
|
||||
def _transform_response(self, response: EmbeddingResponse) -> EmbeddingResponse:
|
||||
additional_headers: Optional[dict] = response._hidden_params.get(
|
||||
"additional_headers"
|
||||
)
|
||||
if additional_headers:
|
||||
# CALCULATE USAGE
|
||||
input_tokens: Optional[str] = additional_headers.get(
|
||||
"llm_provider-num_tokens"
|
||||
)
|
||||
if input_tokens:
|
||||
if response.usage:
|
||||
response.usage.prompt_tokens = int(input_tokens)
|
||||
else:
|
||||
response.usage = Usage(prompt_tokens=int(input_tokens))
|
||||
|
||||
# SET MODEL
|
||||
base_model: Optional[str] = additional_headers.get(
|
||||
"llm_provider-azureml-model-group"
|
||||
)
|
||||
if base_model:
|
||||
response.model = self._map_azure_model_group(base_model)
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,290 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.openai.openai import OpenAIChatCompletion
|
||||
from litellm.types.llms.azure_ai import ImageEmbeddingRequest
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
from .cohere_transformation import AzureAICohereConfig
|
||||
|
||||
|
||||
class AzureAIEmbedding(OpenAIChatCompletion):
|
||||
def _process_response(
|
||||
self,
|
||||
image_embedding_responses: Optional[List],
|
||||
text_embedding_responses: Optional[List],
|
||||
image_embeddings_idx: List[int],
|
||||
model_response: EmbeddingResponse,
|
||||
input: List,
|
||||
):
|
||||
combined_responses = []
|
||||
if (
|
||||
image_embedding_responses is not None
|
||||
and text_embedding_responses is not None
|
||||
):
|
||||
# Combine and order the results
|
||||
text_idx = 0
|
||||
image_idx = 0
|
||||
|
||||
for idx in range(len(input)):
|
||||
if idx in image_embeddings_idx:
|
||||
combined_responses.append(image_embedding_responses[image_idx])
|
||||
image_idx += 1
|
||||
else:
|
||||
combined_responses.append(text_embedding_responses[text_idx])
|
||||
text_idx += 1
|
||||
|
||||
model_response.data = combined_responses
|
||||
elif image_embedding_responses is not None:
|
||||
model_response.data = image_embedding_responses
|
||||
elif text_embedding_responses is not None:
|
||||
model_response.data = text_embedding_responses
|
||||
|
||||
response = AzureAICohereConfig()._transform_response(response=model_response) # type: ignore
|
||||
|
||||
return response
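The recombination above can be illustrated with placeholder values (hypothetical): image embeddings are slotted back into their original positions and text embeddings fill the rest, in order:

inputs = ["text-0", "image-1", "text-2"]
image_embeddings_idx = [1]
image_embedding_responses = ["img-emb"]
text_embedding_responses = ["txt-emb-0", "txt-emb-2"]

combined, text_i, image_i = [], 0, 0
for idx in range(len(inputs)):
    if idx in image_embeddings_idx:
        combined.append(image_embedding_responses[image_i])
        image_i += 1
    else:
        combined.append(text_embedding_responses[text_i])
        text_i += 1
print(combined)  # ['txt-emb-0', 'img-emb', 'txt-emb-2']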
|
||||
|
||||
async def async_image_embedding(
|
||||
self,
|
||||
model: str,
|
||||
data: ImageEmbeddingRequest,
|
||||
timeout: float,
|
||||
logging_obj,
|
||||
model_response: litellm.EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> EmbeddingResponse:
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.AZURE_AI,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
|
||||
url = "{}/images/embeddings".format(api_base)
|
||||
|
||||
response = await client.post(
|
||||
url=url,
|
||||
json=data, # type: ignore
|
||||
headers={"Authorization": "Bearer {}".format(api_key)},
|
||||
)
|
||||
|
||||
embedding_response = response.json()
|
||||
embedding_headers = dict(response.headers)
|
||||
returned_response: EmbeddingResponse = convert_to_model_response_object( # type: ignore
|
||||
response_object=embedding_response,
|
||||
model_response_object=model_response,
|
||||
response_type="embedding",
|
||||
stream=False,
|
||||
_response_headers=embedding_headers,
|
||||
)
|
||||
return returned_response
|
||||
|
||||
def image_embedding(
|
||||
self,
|
||||
model: str,
|
||||
data: ImageEmbeddingRequest,
|
||||
timeout: float,
|
||||
logging_obj,
|
||||
model_response: EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base is None. Please set AZURE_AI_API_BASE or dynamically via `api_base` param, to make the request."
|
||||
)
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"api_key is None. Please set AZURE_AI_API_KEY or dynamically via `api_key` param, to make the request."
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(timeout=timeout, concurrent_limit=1)
|
||||
|
||||
url = "{}/images/embeddings".format(api_base)
|
||||
|
||||
response = client.post(
|
||||
url=url,
|
||||
json=data, # type: ignore
|
||||
headers={"Authorization": "Bearer {}".format(api_key)},
|
||||
)
|
||||
|
||||
embedding_response = response.json()
|
||||
embedding_headers = dict(response.headers)
|
||||
returned_response: EmbeddingResponse = convert_to_model_response_object( # type: ignore
|
||||
response_object=embedding_response,
|
||||
model_response_object=model_response,
|
||||
response_type="embedding",
|
||||
stream=False,
|
||||
_response_headers=embedding_headers,
|
||||
)
|
||||
return returned_response
|
||||
|
||||
async def async_embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: List,
|
||||
timeout: float,
|
||||
logging_obj,
|
||||
model_response: litellm.EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
) -> EmbeddingResponse:
|
||||
(
|
||||
image_embeddings_request,
|
||||
v1_embeddings_request,
|
||||
image_embeddings_idx,
|
||||
) = AzureAICohereConfig()._transform_request(
|
||||
input=input, optional_params=optional_params, model=model
|
||||
)
|
||||
|
||||
image_embedding_responses: Optional[List] = None
|
||||
text_embedding_responses: Optional[List] = None
|
||||
|
||||
if image_embeddings_request["input"]:
|
||||
image_response = await self.async_image_embedding(
|
||||
model=model,
|
||||
data=image_embeddings_request,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
)
|
||||
|
||||
image_embedding_responses = image_response.data
|
||||
if image_embedding_responses is None:
|
||||
raise Exception("/image/embeddings route returned None Embeddings.")
|
||||
|
||||
if v1_embeddings_request["input"]:
|
||||
response: EmbeddingResponse = await super().embedding( # type: ignore
|
||||
model=model,
|
||||
input=input,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
aembedding=True,
|
||||
)
|
||||
text_embedding_responses = response.data
|
||||
if text_embedding_responses is None:
|
||||
raise Exception("/v1/embeddings route returned None Embeddings.")
|
||||
|
||||
return self._process_response(
|
||||
image_embedding_responses=image_embedding_responses,
|
||||
text_embedding_responses=text_embedding_responses,
|
||||
image_embeddings_idx=image_embeddings_idx,
|
||||
model_response=model_response,
|
||||
input=input,
|
||||
)
|
||||
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: List,
|
||||
timeout: float,
|
||||
logging_obj,
|
||||
model_response: EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
client=None,
|
||||
aembedding=None,
|
||||
max_retries: Optional[int] = None,
|
||||
) -> EmbeddingResponse:
|
||||
"""
|
||||
- Separate image url from text
|
||||
-> route image url call to `/image/embeddings`
|
||||
-> route text call to `/v1/embeddings` (OpenAI route)
|
||||
|
||||
assemble result in-order, and return
|
||||
"""
|
||||
if aembedding is True:
|
||||
return self.async_embedding( # type: ignore
|
||||
model,
|
||||
input,
|
||||
timeout,
|
||||
logging_obj,
|
||||
model_response,
|
||||
optional_params,
|
||||
api_key,
|
||||
api_base,
|
||||
client,
|
||||
)
|
||||
|
||||
(
|
||||
image_embeddings_request,
|
||||
v1_embeddings_request,
|
||||
image_embeddings_idx,
|
||||
) = AzureAICohereConfig()._transform_request(
|
||||
input=input, optional_params=optional_params, model=model
|
||||
)
|
||||
|
||||
image_embedding_responses: Optional[List] = None
|
||||
text_embedding_responses: Optional[List] = None
|
||||
|
||||
if image_embeddings_request["input"]:
|
||||
image_response = self.image_embedding(
|
||||
model=model,
|
||||
data=image_embeddings_request,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
client=client,
|
||||
)
|
||||
|
||||
image_embedding_responses = image_response.data
|
||||
if image_embedding_responses is None:
|
||||
raise Exception("/image/embeddings route returned None Embeddings.")
|
||||
|
||||
if v1_embeddings_request["input"]:
|
||||
response: EmbeddingResponse = super().embedding( # type: ignore
|
||||
model,
|
||||
input,
|
||||
timeout,
|
||||
logging_obj,
|
||||
model_response,
|
||||
optional_params,
|
||||
api_key,
|
||||
api_base,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, OpenAI)
|
||||
else None
|
||||
),
|
||||
aembedding=aembedding,
|
||||
)
|
||||
|
||||
text_embedding_responses = response.data
|
||||
if text_embedding_responses is None:
|
||||
raise Exception("/v1/embeddings route returned None Embeddings.")
|
||||
|
||||
return self._process_response(
|
||||
image_embedding_responses=image_embedding_responses,
|
||||
text_embedding_responses=text_embedding_responses,
|
||||
image_embeddings_idx=image_embeddings_idx,
|
||||
model_response=model_response,
|
||||
input=input,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.utils import RerankResponse
|
||||
|
||||
|
||||
class AzureAIRerankConfig(CohereRerankConfig):
|
||||
"""
|
||||
Azure AI Rerank - Follows the same Spec as Cohere Rerank
|
||||
"""
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
|
||||
)
|
||||
if not api_base.endswith("/v1/rerank"):
|
||||
api_base = f"{api_base}/v1/rerank"
|
||||
return api_base
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
|
||||
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
|
||||
)
|
||||
|
||||
default_headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
# If 'Authorization' is provided in headers, it overrides the default.
|
||||
if "Authorization" in headers:
|
||||
default_headers["Authorization"] = headers["Authorization"]
|
||||
|
||||
# Merge other headers, overriding any default ones except Authorization
|
||||
return {**default_headers, **headers}
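A short sketch of the header merge above (key values are hypothetical): caller-supplied headers win, including `Authorization`, while unset defaults are kept:

default_headers = {
    "Authorization": "Bearer AZURE_AI_API_KEY",
    "accept": "application/json",
    "content-type": "application/json",
}
caller_headers = {"Authorization": "Bearer my-own-token", "x-request-id": "abc-123"}

print({**default_headers, **caller_headers})
# {'Authorization': 'Bearer my-own-token', 'accept': 'application/json',
#  'content-type': 'application/json', 'x-request-id': 'abc-123'}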
|
||||
|
||||
def transform_rerank_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: RerankResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
request_data: dict = {},
|
||||
optional_params: dict = {},
|
||||
litellm_params: dict = {},
|
||||
) -> RerankResponse:
|
||||
rerank_response = super().transform_rerank_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
request_data=request_data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
base_model = self._get_base_model(
|
||||
rerank_response._hidden_params.get("llm_provider-azureml-model-group")
|
||||
)
|
||||
rerank_response._hidden_params["model"] = base_model
|
||||
return rerank_response
|
||||
|
||||
def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
|
||||
if azure_model_group is None:
|
||||
return None
|
||||
if azure_model_group == "offer-cohere-rerank-mul-paygo":
|
||||
return "azure_ai/cohere-rerank-v3-multilingual"
|
||||
if azure_model_group == "offer-cohere-rerank-eng-paygo":
|
||||
return "azure_ai/cohere-rerank-v3-english"
|
||||
return azure_model_group
|
||||
89
.venv/lib/python3.10/site-packages/litellm/llms/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
## This is a template base class to be used for adding new LLM providers via API calls
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.types.utils import ModelResponse, TextCompletionResponse
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
_client_session: Optional[httpx.Client] = None
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Any,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
"""
|
||||
Helper function to process the response across sync + async completion calls
|
||||
"""
|
||||
return model_response
|
||||
|
||||
def process_text_completion_response(
|
||||
self,
|
||||
model: str,
|
||||
response: httpx.Response,
|
||||
model_response: TextCompletionResponse,
|
||||
stream: bool,
|
||||
logging_obj: Any,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: list,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> Union[TextCompletionResponse, CustomStreamWrapper]:
|
||||
"""
|
||||
Helper function to process the response across sync + async completion calls
|
||||
"""
|
||||
return model_response
|
||||
|
||||
def create_client_session(self):
|
||||
if litellm.client_session:
|
||||
_client_session = litellm.client_session
|
||||
else:
|
||||
_client_session = httpx.Client()
|
||||
|
||||
return _client_session
|
||||
|
||||
def create_aclient_session(self):
|
||||
if litellm.aclient_session:
|
||||
_aclient_session = litellm.aclient_session
|
||||
else:
|
||||
_aclient_session = httpx.AsyncClient()
|
||||
|
||||
return _aclient_session
|
||||
|
||||
def __exit__(self):
|
||||
if hasattr(self, "_client_session") and self._client_session is not None:
|
||||
self._client_session.close()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if hasattr(self, "_aclient_session"):
|
||||
await self._aclient_session.aclose() # type: ignore
|
||||
|
||||
def validate_environment(
|
||||
self, *args, **kwargs
|
||||
) -> Optional[Any]: # set up the environment required to run the model
|
||||
return None
|
||||
|
||||
def completion(
|
||||
self, *args, **kwargs
|
||||
) -> Any: # logic for parsing in - calling - parsing out model completion calls
|
||||
return None
|
||||
|
||||
def embedding(
|
||||
self, *args, **kwargs
|
||||
) -> Any: # logic for parsing in - calling - parsing out model embedding calls
|
||||
return None
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,35 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseAnthropicMessagesConfig(ABC):
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
return api_base or ""
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_anthropic_messages_params(self, model: str) -> list:
|
||||
pass
|
||||
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIAudioTranscriptionOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import FileTypes, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseAudioTranscriptionConfig(BaseConfig, ABC):
|
||||
@abstractmethod
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIAudioTranscriptionOptionalParams]:
|
||||
pass
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
return api_base or ""
|
||||
|
||||
@abstractmethod
|
||||
def transform_audio_transcription_request(
|
||||
self,
|
||||
model: str,
|
||||
audio_file: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> Union[dict, bytes]:
|
||||
raise NotImplementedError(
|
||||
"AudioTranscriptionConfig needs a request transformation for audio transcription models"
|
||||
)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"AudioTranscriptionConfig does not need a request transformation for audio transcription models"
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"AudioTranscriptionConfig does not need a response transformation for audio transcription models"
|
||||
)
|
||||
@@ -0,0 +1,212 @@
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
Delta,
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
|
||||
class BaseModelResponseIterator:
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.json_mode = json_mode
|
||||
|
||||
def chunk_parser(
|
||||
self, chunk: dict
|
||||
) -> Union[GenericStreamingChunk, ModelResponseStream]:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def _handle_string_chunk(
|
||||
self, str_line: str
|
||||
) -> Union[GenericStreamingChunk, ModelResponseStream]:
|
||||
# chunk is a str at this point
|
||||
|
||||
stripped_chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
|
||||
str_line
|
||||
)
|
||||
try:
|
||||
if stripped_chunk is not None:
|
||||
stripped_json_chunk: Optional[dict] = json.loads(stripped_chunk)
|
||||
else:
|
||||
stripped_json_chunk = None
|
||||
except json.JSONDecodeError:
|
||||
stripped_json_chunk = None
|
||||
|
||||
if "[DONE]" in str_line:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
elif stripped_json_chunk:
|
||||
return self.chunk_parser(chunk=stripped_json_chunk)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
# chunk is a str at this point
|
||||
return self._handle_string_chunk(str_line=str_line)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
# chunk is a str at this point
|
||||
chunk = self._handle_string_chunk(str_line=str_line)
|
||||
|
||||
return chunk
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
|
||||
class MockResponseIterator: # for returning ai21 streaming responses
|
||||
def __init__(
|
||||
self, model_response: ModelResponse, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.model_response = model_response
|
||||
self.json_mode = json_mode
|
||||
self.is_done = False
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def _chunk_parser(self, chunk_data: ModelResponse) -> ModelResponseStream:
|
||||
try:
|
||||
streaming_choices: List[StreamingChoices] = []
|
||||
for choice in chunk_data.choices:
|
||||
streaming_choices.append(
|
||||
StreamingChoices(
|
||||
index=choice.index,
|
||||
delta=Delta(
|
||||
**cast(Choices, choice).message.model_dump(),
|
||||
),
|
||||
finish_reason=choice.finish_reason,
|
||||
)
|
||||
)
|
||||
processed_chunk = ModelResponseStream(
|
||||
id=chunk_data.id,
|
||||
object="chat.completion",
|
||||
created=chunk_data.created,
|
||||
model=chunk_data.model,
|
||||
choices=streaming_choices,
|
||||
)
|
||||
return processed_chunk
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}")
|
||||
|
||||
def __next__(self):
|
||||
if self.is_done:
|
||||
raise StopIteration
|
||||
self.is_done = True
|
||||
return self._chunk_parser(self.model_response)
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.is_done:
|
||||
raise StopAsyncIteration
|
||||
self.is_done = True
|
||||
return self._chunk_parser(self.model_response)
|
||||
|
||||
|
||||
class FakeStreamResponseIterator:
|
||||
def __init__(self, model_response, json_mode: Optional[bool] = False):
|
||||
self.model_response = model_response
|
||||
self.json_mode = json_mode
|
||||
self.is_done = False
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
pass
|
||||
|
||||
def __next__(self):
|
||||
if self.is_done:
|
||||
raise StopIteration
|
||||
self.is_done = True
|
||||
return self.chunk_parser(self.model_response)
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.is_done:
|
||||
raise StopAsyncIteration
|
||||
self.is_done = True
|
||||
return self.chunk_parser(self.model_response)
|
||||
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Utility functions for base LLM classes.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from openai.lib import _parsing, _pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
|
||||
from litellm.types.utils import Message, ProviderSpecificModelInfo
|
||||
|
||||
|
||||
class BaseLLMModelInfo(ABC):
|
||||
def get_provider_info(
|
||||
self,
|
||||
model: str,
|
||||
) -> Optional[ProviderSpecificModelInfo]:
|
||||
"""
|
||||
Default values that all models of this provider support.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def get_models(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns a list of models supported by this provider.
|
||||
"""
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_base_model(model: str) -> Optional[str]:
|
||||
"""
|
||||
Returns the base model name from the given model name.
|
||||
|
||||
Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
|
||||
This function will return `anthropic.claude-3-opus-20240229-v1:0`
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _convert_tool_response_to_message(
|
||||
tool_calls: List[ChatCompletionToolCallChunk],
|
||||
) -> Optional[Message]:
|
||||
"""
|
||||
In JSON mode, Anthropic API returns JSON schema as a tool call, we need to convert it to a message to follow the OpenAI format
|
||||
|
||||
"""
|
||||
## HANDLE JSON MODE - anthropic returns single function call
|
||||
json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
|
||||
try:
|
||||
if json_mode_content_str is not None:
|
||||
args = json.loads(json_mode_content_str)
|
||||
if isinstance(args, dict) and (values := args.get("values")) is not None:
|
||||
_message = Message(content=json.dumps(values))
|
||||
return _message
|
||||
else:
|
||||
# often the `values` key is not present in the tool response
|
||||
# relevant issue: https://github.com/BerriAI/litellm/issues/6741
|
||||
_message = Message(content=json.dumps(args))
|
||||
return _message
|
||||
except json.JSONDecodeError:
|
||||
# json decode error does occur, return the original tool response str
|
||||
return Message(content=json_mode_content_str)
|
||||
return None
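With a hypothetical JSON-mode tool call, the conversion above reduces to the following sketch:

import json

tool_call = {"function": {"arguments": '{"values": {"city": "Paris", "temp_c": 21}}'}}

args = json.loads(tool_call["function"]["arguments"])
if isinstance(args, dict) and "values" in args:
    content = json.dumps(args["values"])   # unwrap the `values` key when present
else:
    content = json.dumps(args)             # otherwise return the arguments as-is
print(content)  # {"city": "Paris", "temp_c": 21}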
|
||||
|
||||
|
||||
def _dict_to_response_format_helper(
|
||||
response_format: dict, ref_template: Optional[str] = None
|
||||
) -> dict:
|
||||
if ref_template is not None and response_format.get("type") == "json_schema":
|
||||
# Deep copy to avoid modifying original
|
||||
modified_format = copy.deepcopy(response_format)
|
||||
schema = modified_format["json_schema"]["schema"]
|
||||
|
||||
# Update all $ref values in the schema
|
||||
def update_refs(schema):
|
||||
stack = [(schema, [])]
|
||||
visited = set()
|
||||
|
||||
while stack:
|
||||
obj, path = stack.pop()
|
||||
obj_id = id(obj)
|
||||
|
||||
if obj_id in visited:
|
||||
continue
|
||||
visited.add(obj_id)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
if "$ref" in obj:
|
||||
ref_path = obj["$ref"]
|
||||
model_name = ref_path.split("/")[-1]
|
||||
obj["$ref"] = ref_template.format(model=model_name)
|
||||
|
||||
for k, v in obj.items():
|
||||
if isinstance(v, (dict, list)):
|
||||
stack.append((v, path + [k]))
|
||||
|
||||
elif isinstance(obj, list):
|
||||
for i, item in enumerate(obj):
|
||||
if isinstance(item, (dict, list)):
|
||||
stack.append((item, path + [i]))
|
||||
|
||||
update_refs(schema)
|
||||
return modified_format
|
||||
return response_format
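A compact (recursive rather than stack-based) sketch of the `$ref` rewriting above, using a hypothetical ref_template and schema:

ref_template = "https://example.com/schemas/{model}"
schema = {"$ref": "#/$defs/Address", "properties": {"home": {"$ref": "#/$defs/Address"}}}

def rewrite_refs(obj):
    # Replace every "$ref" with the templated URL, walking nested dicts/lists.
    if isinstance(obj, dict):
        if "$ref" in obj:
            obj["$ref"] = ref_template.format(model=obj["$ref"].split("/")[-1])
        for value in obj.values():
            rewrite_refs(value)
    elif isinstance(obj, list):
        for item in obj:
            rewrite_refs(item)

rewrite_refs(schema)
print(schema["$ref"])                        # https://example.com/schemas/Address
print(schema["properties"]["home"]["$ref"])  # https://example.com/schemas/Address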
|
||||
|
||||
|
||||
def type_to_response_format_param(
|
||||
response_format: Optional[Union[Type[BaseModel], dict]],
|
||||
ref_template: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Re-implementation of openai's 'type_to_response_format_param' function
|
||||
|
||||
Used for converting pydantic object to api schema.
|
||||
"""
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if isinstance(response_format, dict):
|
||||
return _dict_to_response_format_helper(response_format, ref_template)
|
||||
|
||||
# type checkers don't narrow the negation of a `TypeGuard` as it isn't
|
||||
# a safe default behaviour but we know that at this point the `response_format`
|
||||
# can only be a `type`
|
||||
if not _parsing._completions.is_basemodel_type(response_format):
|
||||
raise TypeError(f"Unsupported response_format type - {response_format}")
|
||||
|
||||
if ref_template is not None:
|
||||
schema = response_format.model_json_schema(ref_template=ref_template)
|
||||
else:
|
||||
schema = _pydantic.to_strict_json_schema(response_format)
|
||||
|
||||
return {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"schema": schema,
|
||||
"name": response_format.__name__,
|
||||
"strict": True,
|
||||
},
|
||||
}
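A hedged usage sketch, assuming this module is importable as `litellm.llms.base_llm.base_utils` (per the relative import used elsewhere in this commit):

from pydantic import BaseModel

from litellm.llms.base_llm.base_utils import type_to_response_format_param

class Weather(BaseModel):
    city: str
    temp_c: float

rf = type_to_response_format_param(Weather)
print(rf["type"])                   # json_schema
print(rf["json_schema"]["name"])    # Weather
print(rf["json_schema"]["strict"])  # True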
|
||||
|
||||
|
||||
def map_developer_role_to_system_role(
|
||||
messages: List[AllMessageValues],
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Translate `developer` role to `system` role for non-OpenAI providers.
|
||||
"""
|
||||
new_messages: List[AllMessageValues] = []
|
||||
for m in messages:
|
||||
if m["role"] == "developer":
|
||||
verbose_logger.debug(
|
||||
"Translating developer role to system role for non-OpenAI providers."
|
||||
) # ensure user knows what's happening with their input.
|
||||
new_messages.append({"role": "system", "content": m["content"]})
|
||||
else:
|
||||
new_messages.append(m)
|
||||
return new_messages
|
||||
Binary file not shown.
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
Common base config for all LLM providers
|
||||
"""
|
||||
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.constants import DEFAULT_MAX_TOKENS, RESPONSE_FORMAT_TOOL_NAME
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionToolChoiceFunctionParam,
|
||||
ChatCompletionToolChoiceObjectParam,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from ..base_utils import (
|
||||
map_developer_role_to_system_role,
|
||||
type_to_response_format_param,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseLLMException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Optional[Union[dict, httpx.Headers]] = None,
|
||||
request: Optional[httpx.Request] = None,
|
||||
response: Optional[httpx.Response] = None,
|
||||
body: Optional[dict] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message: str = message
|
||||
self.headers = headers
|
||||
if request:
|
||||
self.request = request
|
||||
else:
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://docs.litellm.ai/docs"
|
||||
)
|
||||
if response:
|
||||
self.response = response
|
||||
else:
|
||||
self.response = httpx.Response(
|
||||
status_code=status_code, request=self.request
|
||||
)
|
||||
self.body = body
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class BaseConfig(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_json_schema_from_pydantic_object(
|
||||
self, response_format: Optional[Union[Type[BaseModel], dict]]
|
||||
) -> Optional[dict]:
|
||||
return type_to_response_format_param(response_format=response_format)
|
||||
|
||||
def is_thinking_enabled(self, non_default_params: dict) -> bool:
|
||||
return (
|
||||
non_default_params.get("thinking", {}).get("type") == "enabled"
|
||||
or non_default_params.get("reasoning_effort") is not None
|
||||
)
|
||||
|
||||
def update_optional_params_with_thinking_tokens(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
):
|
||||
"""
|
||||
Handles the scenario where max tokens is not specified. For anthropic models (Anthropic API / Bedrock / Vertex AI), max tokens must be set and must be greater than the thinking token budget.
|
||||
|
||||
Checks 'non_default_params' for 'thinking' and 'max_tokens'
|
||||
|
||||
if 'thinking' is enabled and 'max_tokens' is not specified, set 'max_tokens' to the thinking token budget + DEFAULT_MAX_TOKENS
|
||||
"""
|
||||
is_thinking_enabled = self.is_thinking_enabled(optional_params)
|
||||
if is_thinking_enabled and "max_tokens" not in non_default_params:
|
||||
thinking_token_budget = cast(dict, optional_params["thinking"]).get(
|
||||
"budget_tokens", None
|
||||
)
|
||||
if thinking_token_budget is not None:
|
||||
optional_params["max_tokens"] = (
|
||||
thinking_token_budget + DEFAULT_MAX_TOKENS
|
||||
)
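A sketch of the fallback above with hypothetical params; the real DEFAULT_MAX_TOKENS comes from litellm.constants, 4096 is used here only for illustration:

DEFAULT_MAX_TOKENS = 4096  # illustrative value, not necessarily litellm's constant

non_default_params = {"thinking": {"type": "enabled", "budget_tokens": 1024}}
optional_params = dict(non_default_params)

if "max_tokens" not in non_default_params:
    budget = optional_params["thinking"].get("budget_tokens")
    if budget is not None:
        optional_params["max_tokens"] = budget + DEFAULT_MAX_TOKENS

print(optional_params["max_tokens"])  # 5120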
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if the model/provider should fake stream
|
||||
"""
|
||||
return False
|
||||
|
||||
def _add_tools_to_optional_params(self, optional_params: dict, tools: List) -> dict:
|
||||
"""
|
||||
Helper util to add tools to optional_params.
|
||||
"""
|
||||
if "tools" not in optional_params:
|
||||
optional_params["tools"] = tools
|
||||
else:
|
||||
optional_params["tools"] = [
|
||||
*optional_params["tools"],
|
||||
*tools,
|
||||
]
|
||||
return optional_params
|
||||
|
||||
def translate_developer_role_to_system_role(
|
||||
self,
|
||||
messages: List[AllMessageValues],
|
||||
) -> List[AllMessageValues]:
|
||||
"""
|
||||
Translate `developer` role to `system` role for non-OpenAI providers.
|
||||
|
||||
Overridden by OpenAI/Azure
|
||||
"""
|
||||
return map_developer_role_to_system_role(messages=messages)
|
||||
|
||||
def should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
self, e: httpx.HTTPStatusError, litellm_params: dict
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if the model/provider should retry the LLM API on UnprocessableEntityError
|
||||
|
||||
Overridden by Azure AI - where different models support different parameters
|
||||
"""
|
||||
return False
|
||||
|
||||
def transform_request_on_unprocessable_entity_error(
|
||||
self, e: httpx.HTTPStatusError, request_data: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the request data on UnprocessableEntityError
|
||||
"""
|
||||
return request_data
|
||||
|
||||
@property
|
||||
def max_retry_on_unprocessable_entity_error(self) -> int:
|
||||
"""
|
||||
Returns the max retry count for UnprocessableEntityError
|
||||
|
||||
Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True
|
||||
"""
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
pass
|
||||
|
||||
def _add_response_format_to_tools(
|
||||
self,
|
||||
optional_params: dict,
|
||||
value: dict,
|
||||
is_response_format_supported: bool,
|
||||
enforce_tool_choice: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Follow similar approach to anthropic - translate to a single tool call.
|
||||
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
|
||||
Add response format to tools
|
||||
|
||||
This is used to translate response_format to a tool call, for models/APIs that don't support response_format directly.
|
||||
"""
|
||||
json_schema: Optional[dict] = None
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
|
||||
if json_schema and not is_response_format_supported:
|
||||
_tool_choice = ChatCompletionToolChoiceObjectParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolChoiceFunctionParam(
|
||||
name=RESPONSE_FORMAT_TOOL_NAME
|
||||
),
|
||||
)
|
||||
|
||||
_tool = ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name=RESPONSE_FORMAT_TOOL_NAME, parameters=json_schema
|
||||
),
|
||||
)
|
||||
|
||||
optional_params.setdefault("tools", [])
|
||||
optional_params["tools"].append(_tool)
|
||||
if enforce_tool_choice:
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
|
||||
optional_params["json_mode"] = True
|
||||
elif is_response_format_supported:
|
||||
optional_params["response_format"] = value
|
||||
return optional_params
|
||||
|
||||
@abstractmethod
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
def sign_request(
|
||||
self,
|
||||
headers: dict,
|
||||
optional_params: dict,
|
||||
request_data: dict,
|
||||
api_base: str,
|
||||
model: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
fake_stream: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Some providers like Bedrock require signing the request. The sign request funtion needs access to `request_data` and `complete_url`
|
||||
Args:
|
||||
headers: dict
|
||||
optional_params: dict
|
||||
request_data: dict - the request body being sent in http request
|
||||
api_base: str - the complete url being sent in http request
|
||||
Returns:
|
||||
dict - the signed headers
|
||||
|
||||
Update the headers with the signed headers in this function. The return values will be sent as headers in the http request.
|
||||
"""
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required")
|
||||
return api_base
|
||||
|
||||
@abstractmethod
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
pass
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
def get_async_custom_stream_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
custom_llm_provider: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
messages: list,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_sync_custom_stream_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
custom_llm_provider: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: dict,
|
||||
messages: list,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
@property
|
||||
def has_custom_stream_wrapper(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_stream_param_in_request_body(self) -> bool:
|
||||
"""
|
||||
Some providers like Bedrock invoke do not support the stream parameter in the request body.
|
||||
|
||||
By default, this is true for almost all providers.
|
||||
"""
|
||||
return True
|
||||
Binary file not shown.
@@ -0,0 +1,75 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseTextCompletionConfig(BaseConfig, ABC):
|
||||
@abstractmethod
|
||||
def transform_text_completion_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
return api_base or ""
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"AudioTranscriptionConfig does not need a request transformation for audio transcription models"
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"AudioTranscriptionConfig does not need a response transformation for audio transcription models"
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,89 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseEmbeddingConfig(BaseConfig, ABC):
|
||||
@abstractmethod
|
||||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: AllEmbeddingInputValues,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def transform_embedding_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str],
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> EmbeddingResponse:
|
||||
return model_response
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
return api_base or ""
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"EmbeddingConfig does not need a request transformation for chat models"
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"EmbeddingConfig does not need a response transformation for chat models"
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,101 @@
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateFileRequest,
|
||||
OpenAICreateFileRequestOptionalParams,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders, ModelResponse
|
||||
|
||||
from ..chat.transformation import BaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseFilesConfig(BaseConfig):
|
||||
@property
|
||||
@abstractmethod
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAICreateFileRequestOptionalParams]:
|
||||
pass
|
||||
|
||||
def get_complete_file_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
data: CreateFileRequest,
|
||||
):
|
||||
return self.get_complete_url(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def transform_create_file_request(
|
||||
self,
|
||||
model: str,
|
||||
create_file_data: CreateFileRequest,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> Union[dict, str, bytes]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_create_file_response(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> OpenAIFileObject:
|
||||
pass
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"AudioTranscriptionConfig does not need a request transformation for audio transcription models"
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"AudioTranscriptionConfig does not need a response transformation for audio transcription models"
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,134 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageVariationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
FileTypes,
|
||||
HttpHandlerRequestFields,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseImageVariationConfig(BaseConfig, ABC):
|
||||
@abstractmethod
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageVariationOptionalParams]:
|
||||
pass
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
return api_base or ""
|
||||
|
||||
@abstractmethod
|
||||
def transform_request_image_variation(
|
||||
self,
|
||||
model: Optional[str],
|
||||
image: FileTypes,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> HttpHandlerRequestFields:
|
||||
pass
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
async def async_transform_response_image_variation(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: ClientResponse,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
image: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
) -> ImageResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_response_image_variation(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
image: FileTypes,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
) -> ImageResponse:
|
||||
pass
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
raise NotImplementedError(
|
||||
"ImageVariationConfig implementa 'transform_request_image_variation' for image variation models"
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"ImageVariationConfig implements 'transform_response_image_variation' for image variation models"
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,128 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.rerank import OptionalRerankParams, RerankBilledUnits, RerankResponse
|
||||
from litellm.types.utils import ModelInfo
|
||||
|
||||
from ..chat.transformation import BaseLLMException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BaseRerankConfig(ABC):
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_rerank_request(
|
||||
self,
|
||||
model: str,
|
||||
optional_rerank_params: OptionalRerankParams,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def transform_rerank_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: RerankResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
request_data: dict = {},
|
||||
optional_params: dict = {},
|
||||
litellm_params: dict = {},
|
||||
) -> RerankResponse:
|
||||
return model_response
|
||||
|
||||
@abstractmethod
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
return api_base or ""
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_cohere_rerank_params(self, model: str) -> list:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def map_cohere_rerank_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
pass
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
raise BaseLLMException(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def calculate_rerank_cost(
|
||||
self,
|
||||
model: str,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
billed_units: Optional[RerankBilledUnits] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculates the cost per query for a given rerank model.
|
||||
|
||||
Input:
|
||||
- model: str, the model name without provider prefix
|
||||
- custom_llm_provider: str, the provider used for the model. If provided, used to check if the litellm model info is for that provider.
|
||||
- num_queries: int, the number of queries to calculate the cost for
|
||||
- model_info: ModelInfo, the model info for the given model
|
||||
|
||||
Returns:
|
||||
Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
|
||||
"""
|
||||
|
||||
if (
|
||||
model_info is None
|
||||
or "input_cost_per_query" not in model_info
|
||||
or model_info["input_cost_per_query"] is None
|
||||
or billed_units is None
|
||||
):
|
||||
return 0.0, 0.0
|
||||
|
||||
search_units = billed_units.get("search_units")
|
||||
|
||||
if search_units is None:
|
||||
return 0.0, 0.0
|
||||
|
||||
prompt_cost = model_info["input_cost_per_query"] * search_units
|
||||
|
||||
return prompt_cost, 0.0
|
||||
Binary file not shown.
@@ -0,0 +1,165 @@
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import (
|
||||
ResponseInputParam,
|
||||
ResponsesAPIOptionalRequestParams,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamingResponse,
|
||||
)
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ..chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class BaseResponsesAPIConfig(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def map_openai_params(
|
||||
self,
|
||||
response_api_optional_params: ResponsesAPIOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required")
|
||||
return api_base
|
||||
|
||||
@abstractmethod
|
||||
def transform_responses_api_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, ResponseInputParam],
|
||||
response_api_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_response_api_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_streaming_response(
|
||||
self,
|
||||
model: str,
|
||||
parsed_chunk: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIStreamingResponse:
|
||||
"""
|
||||
Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
|
||||
"""
|
||||
pass
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
#########################################################
|
||||
@abstractmethod
|
||||
def transform_delete_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform_delete_response_api_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> DeleteResponseResult:
|
||||
pass
|
||||
|
||||
#########################################################
|
||||
########## END DELETE RESPONSE API TRANSFORMATION ##########
|
||||
#########################################################
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
from ..chat.transformation import BaseLLMException
|
||||
|
||||
raise BaseLLMException(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def should_fake_stream(
|
||||
self,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Returns True if litellm should fake a stream for the given model and stream value"""
|
||||
return False
|
||||
172
.venv/lib/python3.10/site-packages/litellm/llms/baseten.py
Normal file
172
.venv/lib/python3.10/site-packages/litellm/llms/baseten.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
|
||||
class BasetenError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
def validate_environment(api_key):
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Api-Key {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
headers = validate_environment(api_key)
|
||||
completion_url_fragment_1 = "https://app.baseten.co/models/"
|
||||
completion_url_fragment_2 = "/predict"
|
||||
model = model
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
if message["role"] == "user":
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
data = {
|
||||
"inputs": prompt,
|
||||
"prompt": prompt,
|
||||
"parameters": optional_params,
|
||||
"stream": (
|
||||
True
|
||||
if "stream" in optional_params and optional_params["stream"] is True
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = litellm.module_level_client.post(
|
||||
completion_url_fragment_1 + model + completion_url_fragment_2,
|
||||
headers=headers,
|
||||
data=json.dumps(data),
|
||||
stream=(
|
||||
True
|
||||
if "stream" in optional_params and optional_params["stream"] is True
|
||||
else False
|
||||
),
|
||||
)
|
||||
if "text/event-stream" in response.headers["Content-Type"] or (
|
||||
"stream" in optional_params and optional_params["stream"] is True
|
||||
):
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
## RESPONSE OBJECT
|
||||
completion_response = response.json()
|
||||
if "error" in completion_response:
|
||||
raise BasetenError(
|
||||
message=completion_response["error"],
|
||||
status_code=response.status_code,
|
||||
)
|
||||
else:
|
||||
if "model_output" in completion_response:
|
||||
if (
|
||||
isinstance(completion_response["model_output"], dict)
|
||||
and "data" in completion_response["model_output"]
|
||||
and isinstance(completion_response["model_output"]["data"], list)
|
||||
):
|
||||
model_response.choices[0].message.content = completion_response[ # type: ignore
|
||||
"model_output"
|
||||
][
|
||||
"data"
|
||||
][
|
||||
0
|
||||
]
|
||||
elif isinstance(completion_response["model_output"], str):
|
||||
model_response.choices[0].message.content = completion_response[ # type: ignore
|
||||
"model_output"
|
||||
]
|
||||
elif "completion" in completion_response and isinstance(
|
||||
completion_response["completion"], str
|
||||
):
|
||||
model_response.choices[0].message.content = completion_response[ # type: ignore
|
||||
"completion"
|
||||
]
|
||||
elif isinstance(completion_response, list) and len(completion_response) > 0:
|
||||
if "generated_text" not in completion_response:
|
||||
raise BasetenError(
|
||||
message=f"Unable to parse response. Original response: {response.text}",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
model_response.choices[0].message.content = completion_response[0][ # type: ignore
|
||||
"generated_text"
|
||||
]
|
||||
## GETTING LOGPROBS
|
||||
if (
|
||||
"details" in completion_response[0]
|
||||
and "tokens" in completion_response[0]["details"]
|
||||
):
|
||||
model_response.choices[0].finish_reason = completion_response[0][
|
||||
"details"
|
||||
]["finish_reason"]
|
||||
sum_logprob = 0
|
||||
for token in completion_response[0]["details"]["tokens"]:
|
||||
sum_logprob += token["logprob"]
|
||||
model_response.choices[0].logprobs = sum_logprob # type: ignore
|
||||
else:
|
||||
raise BasetenError(
|
||||
message=f"Unable to parse response. Original response: {response.text}",
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"]["content"])
|
||||
)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
pass
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user