structure saas with tools
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
LLM Calling done in `openai/openai.py`
|
||||
"""
|
||||
@@ -0,0 +1,321 @@
|
||||
import enum
|
||||
from typing import Any, List, Optional, Tuple, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from httpx import Response
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
_audio_or_image_in_message_content,
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
|
||||
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
|
||||
from litellm.llms.openai.openai import OpenAIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, ProviderField
|
||||
from litellm.utils import _add_path_to_api_base, supports_tool_choice
|
||||
|
||||
|
||||
class AzureFoundryErrorStrings(str, enum.Enum):
|
||||
SET_EXTRA_PARAMETERS_TO_PASS_THROUGH = "Set extra-parameters to 'pass-through'"
|
||||
|
||||
|
||||
class AzureAIStudioConfig(OpenAIConfig):
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
model_supports_tool_choice = True # azure ai supports this by default
|
||||
if not supports_tool_choice(model=f"azure_ai/{model}"):
|
||||
model_supports_tool_choice = False
|
||||
supported_params = super().get_supported_openai_params(model)
|
||||
if not model_supports_tool_choice:
|
||||
filtered_supported_params = []
|
||||
for param in supported_params:
|
||||
if param != "tool_choice":
|
||||
filtered_supported_params.append(param)
|
||||
return filtered_supported_params
|
||||
return supported_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_base and self._should_use_api_key_header(api_base):
|
||||
headers["api-key"] = api_key
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
return headers
|
||||
|
||||
def _should_use_api_key_header(self, api_base: str) -> bool:
|
||||
"""
|
||||
Returns True if the request should use `api-key` header for authentication.
|
||||
"""
|
||||
parsed_url = urlparse(api_base)
|
||||
host = parsed_url.hostname
|
||||
if host and (
|
||||
host.endswith(".services.ai.azure.com")
|
||||
or host.endswith(".openai.azure.com")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Constructs a complete URL for the API request.
|
||||
|
||||
Args:
|
||||
- api_base: Base URL, e.g.,
|
||||
"https://litellm8397336933.services.ai.azure.com"
|
||||
OR
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
|
||||
- model: Model name.
|
||||
- optional_params: Additional query parameters, including "api_version".
|
||||
- stream: If streaming is required (optional).
|
||||
|
||||
Returns:
|
||||
- A complete URL string, e.g.,
|
||||
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
|
||||
)
|
||||
original_url = httpx.URL(api_base)
|
||||
|
||||
# Extract api_version or use default
|
||||
api_version = cast(Optional[str], litellm_params.get("api_version"))
|
||||
|
||||
# Create a new dictionary with existing params
|
||||
query_params = dict(original_url.params)
|
||||
|
||||
# Add api_version if needed
|
||||
if "api-version" not in query_params and api_version:
|
||||
query_params["api-version"] = api_version
|
||||
|
||||
# Add the path to the base URL
|
||||
if "services.ai.azure.com" in api_base:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/models/chat/completions"
|
||||
)
|
||||
else:
|
||||
new_url = _add_path_to_api_base(
|
||||
api_base=api_base, ending_path="/chat/completions"
|
||||
)
|
||||
|
||||
# Use the new query_params dictionary
|
||||
final_url = httpx.URL(new_url).copy_with(params=query_params)
|
||||
|
||||
return str(final_url)
|
||||
|
||||
def get_required_params(self) -> List[ProviderField]:
|
||||
"""For a given provider, return it's required fields with a description"""
|
||||
return [
|
||||
ProviderField(
|
||||
field_name="api_key",
|
||||
field_type="string",
|
||||
field_description="Your Azure AI Studio API Key.",
|
||||
field_value="zEJ...",
|
||||
),
|
||||
ProviderField(
|
||||
field_name="api_base",
|
||||
field_type="string",
|
||||
field_description="Your Azure AI Studio API Base.",
|
||||
field_value="https://Mistral-serverless.",
|
||||
),
|
||||
]
|
||||
|
||||
def _transform_messages(
|
||||
self,
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
) -> List:
|
||||
"""
|
||||
- Azure AI Studio doesn't support content as a list. This handles:
|
||||
1. Transforms list content to a string.
|
||||
2. If message contains an image or audio, send as is (user-intended)
|
||||
"""
|
||||
for message in messages:
|
||||
# Do nothing if the message contains an image or audio
|
||||
if _audio_or_image_in_message_content(message):
|
||||
continue
|
||||
|
||||
texts = convert_content_list_to_str(message=message)
|
||||
if texts:
|
||||
message["content"] = texts
|
||||
return messages
|
||||
|
||||
def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
|
||||
try:
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
if (
|
||||
model in litellm.open_ai_chat_completion_models
|
||||
or model in litellm.open_ai_text_completion_models
|
||||
or model in litellm.open_ai_embedding_models
|
||||
):
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
custom_llm_provider: str,
|
||||
) -> Tuple[Optional[str], Optional[str], str]:
|
||||
api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
|
||||
dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
|
||||
|
||||
if self._is_azure_openai_model(model=model, api_base=api_base):
|
||||
verbose_logger.debug(
|
||||
"Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
|
||||
model
|
||||
)
|
||||
)
|
||||
custom_llm_provider = "azure"
|
||||
return api_base, dynamic_api_key, custom_llm_provider
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
extra_body = optional_params.pop("extra_body", {})
|
||||
if extra_body and isinstance(extra_body, dict):
|
||||
optional_params.update(extra_body)
|
||||
optional_params.pop("max_retries", None)
|
||||
return super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
model_response.model = f"azure_ai/{model}"
|
||||
return super().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
self, e: httpx.HTTPStatusError, litellm_params: dict
|
||||
) -> bool:
|
||||
should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
|
||||
error_text = e.response.text
|
||||
|
||||
if should_drop_params and "Extra inputs are not permitted" in error_text:
|
||||
return True
|
||||
elif (
|
||||
"unknown field: parameter index is not a valid field" in error_text
|
||||
): # remove index from tool calls
|
||||
return True
|
||||
elif (
|
||||
AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
|
||||
in error_text
|
||||
): # remove extra-parameters from tool calls
|
||||
return True
|
||||
return super().should_retry_llm_api_inside_llm_translation_on_http_error(
|
||||
e=e, litellm_params=litellm_params
|
||||
)
|
||||
|
||||
@property
|
||||
def max_retry_on_unprocessable_entity_error(self) -> int:
|
||||
return 2
|
||||
|
||||
def transform_request_on_unprocessable_entity_error(
|
||||
self, e: httpx.HTTPStatusError, request_data: dict
|
||||
) -> dict:
|
||||
_messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
|
||||
if (
|
||||
"unknown field: parameter index is not a valid field" in e.response.text
|
||||
and _messages is not None
|
||||
):
|
||||
litellm.remove_index_from_tool_calls(
|
||||
messages=_messages,
|
||||
)
|
||||
elif (
|
||||
AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
|
||||
in e.response.text
|
||||
):
|
||||
request_data = self._drop_extra_params_from_request_data(
|
||||
request_data, e.response.text
|
||||
)
|
||||
data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
|
||||
return data
|
||||
|
||||
def _drop_extra_params_from_request_data(
|
||||
self, request_data: dict, error_text: str
|
||||
) -> dict:
|
||||
params_to_drop = self._extract_params_to_drop_from_error_text(error_text)
|
||||
if params_to_drop:
|
||||
for param in params_to_drop:
|
||||
if param in request_data:
|
||||
request_data.pop(param, None)
|
||||
return request_data
|
||||
|
||||
def _extract_params_to_drop_from_error_text(
|
||||
self, error_text: str
|
||||
) -> Optional[List[str]]:
|
||||
"""
|
||||
Error text looks like this"
|
||||
"Extra parameters ['stream_options', 'extra-parameters'] are not allowed when extra-parameters is not set or set to be 'error'.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Extract parameters within square brackets
|
||||
match = re.search(r"\[(.*?)\]", error_text)
|
||||
if not match:
|
||||
return []
|
||||
|
||||
# Parse the extracted string into a list of parameter names
|
||||
params_str = match.group(1)
|
||||
params = []
|
||||
for param in params_str.split(","):
|
||||
# Clean up the parameter name (remove quotes, spaces)
|
||||
clean_param = param.strip().strip("'").strip('"')
|
||||
if clean_param:
|
||||
params.append(clean_param)
|
||||
return params
|
||||
Reference in New Issue
Block a user