structure saas with tools
@@ -0,0 +1,149 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import httpx

from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any

from litellm.llms.base_llm.chat.transformation import BaseLLMException

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping

logger = logging.getLogger(__name__)

BASE_URL = "https://router.huggingface.co"


class HuggingFaceChatConfig(OpenAIGPTConfig):
    """
    Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
    """

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        default_headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            default_headers["Authorization"] = f"Bearer {api_key}"

        headers = {**headers, **default_headers}

        return headers

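    # A minimal doctest-style sketch (hypothetical header and key values). Note
    # the merge order: caller-supplied headers are kept, but the defaults above
    # win on key conflicts.
    #
    #   >>> HuggingFaceChatConfig().validate_environment(
    #   ...     headers={"x-request-id": "abc"}, model="some-provider/org/model",
    #   ...     messages=[], optional_params={}, litellm_params={}, api_key="hf_xxx",
    #   ... )
    #   {'x-request-id': 'abc', 'content-type': 'application/json', 'Authorization': 'Bearer hf_xxx'}
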
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return HuggingFaceError(
            status_code=status_code, message=error_message, headers=headers
        )

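    # Sketch (status code and message assumed): any HTTP failure is normalized
    # into the provider-specific error type, e.g.
    #   config.get_error_class("rate limited", 429, httpx.Headers())
    # returns a HuggingFaceError carrying the status code, message, and headers.
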
    def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
        """
        Get the API base for the Huggingface API.

        Do not add the chat/embedding/rerank extension here. Let the handler do this.
        """
        if model.startswith(("http://", "https://")):
            base_url = model
        elif base_url is None:
            base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
        return base_url

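    # Sketch (assumed values): a URL-shaped model name takes precedence, an
    # explicit base_url is otherwise kept, and the env vars are the fallback:
    #
    #   >>> cfg = HuggingFaceChatConfig()
    #   >>> cfg.get_base_url("https://my-endpoint.example.com/v1", None)
    #   'https://my-endpoint.example.com/v1'
    #   >>> cfg.get_base_url("org/model", None)  # HF_API_BASE/HUGGINGFACE_API_BASE unset
    #   ''
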
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.

        For provider-specific routing through Hugging Face.
        """
        # 1. Check if api_base is provided
        if api_base is not None:
            complete_url = api_base
        # 2. Check the environment variables
        elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
            # `str(os.getenv(...))` would turn an unset variable into the truthy
            # string "None", so chain the raw values with `or` instead.
            complete_url = (
                os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE") or ""
            )
        # 3. A URL-shaped model name is used verbatim
        elif model.startswith(("http://", "https://")):
            complete_url = model
        # 4. Default construction with provider
        else:
            # Parse provider and model
            first_part, remaining = model.split("/", 1)
            if "/" in remaining:
                provider = first_part
            else:
                provider = "hf-inference"

            if provider == "hf-inference":
                route = f"{provider}/models/{model}/v1/chat/completions"
            elif provider == "novita":
                route = f"{provider}/chat/completions"
            else:
                route = f"{provider}/v1/chat/completions"
            complete_url = f"{BASE_URL}/{route}"

        # Ensure URL doesn't end with a slash
        complete_url = complete_url.rstrip("/")
        return complete_url

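    # Routing sketch (assumed model ids, no api_base and no env overrides):
    #
    #   "deepseek-ai/DeepSeek-R1"  -> https://router.huggingface.co/hf-inference/models/deepseek-ai/DeepSeek-R1/v1/chat/completions
    #   "novita/org/model"         -> https://router.huggingface.co/novita/chat/completions
    #   "together/org/model"       -> https://router.huggingface.co/together/v1/chat/completions
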
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        if "max_retries" in optional_params:
            logger.warning("`max_retries` is not supported. It will be ignored.")
            optional_params.pop("max_retries", None)
        first_part, remaining = model.split("/", 1)
        if "/" in remaining:
            provider = first_part
            model_id = remaining
        else:
            provider = "hf-inference"
            model_id = model
        provider_mapping = _fetch_inference_provider_mapping(model_id)
        if provider not in provider_mapping:
            raise HuggingFaceError(
                message=f"Model {model_id} is not supported for provider {provider}",
                status_code=404,
                headers={},
            )
        provider_mapping = provider_mapping[provider]
        if provider_mapping["status"] == "staging":
            logger.warning(
                f"Model {model_id} is in staging mode for provider {provider}. Meant for test purposes only."
            )
        mapped_model = provider_mapping["providerId"]
        messages = self._transform_messages(messages=messages, model=mapped_model)
        return dict(
            ChatCompletionRequest(
                model=mapped_model, messages=messages, **optional_params
            )
        )
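
# A minimal end-to-end sketch (assumed model id and token variable, run
# manually; not part of the class above). It exercises only the URL and header
# helpers, since transform_request fetches the Hub's provider mapping over the
# network.
if __name__ == "__main__":
    config = HuggingFaceChatConfig()
    demo_model = "together/deepseek-ai/DeepSeek-R1"  # hypothetical provider/model
    url = config.get_complete_url(
        api_base=None,
        api_key=None,
        model=demo_model,
        optional_params={},
        litellm_params={},
    )
    headers = config.validate_environment(
        headers={},
        model=demo_model,
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=os.getenv("HF_TOKEN"),  # assumed token variable; any bearer key works
    )
    print(url)      # with no env overrides: https://router.huggingface.co/together/v1/chat/completions
    print(headers)  # content-type plus Authorization when HF_TOKEN is set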