structure saas with tools
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,90 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
strip_name_from_messages,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
XAI_API_BASE = "https://api.x.ai/v1"
|
||||
|
||||
|
||||
class XAIChatConfig(OpenAIGPTConfig):
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return "xai"
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
|
||||
return api_base, dynamic_api_key
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
base_openai_params = [
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"max_tokens",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"response_format",
|
||||
"seed",
|
||||
"stop",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"temperature",
|
||||
"tool_choice",
|
||||
"tools",
|
||||
"top_logprobs",
|
||||
"top_p",
|
||||
"user",
|
||||
]
|
||||
try:
|
||||
if litellm.supports_reasoning(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
base_openai_params.append("reasoning_effort")
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error checking if model supports reasoning: {e}")
|
||||
|
||||
return base_openai_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool = False,
|
||||
) -> dict:
|
||||
supported_openai_params = self.get_supported_openai_params(model=model)
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
elif param in supported_openai_params:
|
||||
if value is not None:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Handle https://github.com/BerriAI/litellm/issues/9720
|
||||
|
||||
Filter out 'name' from messages
|
||||
"""
|
||||
messages = strip_name_from_messages(messages)
|
||||
return super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class XAIModelInfo(BaseLLMModelInfo):
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Ensure Content-Type is set to application/json
|
||||
if "content-type" not in headers and "Content-Type" not in headers:
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
|
||||
return api_key or get_secret_str("XAI_API_KEY")
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: str) -> Optional[str]:
|
||||
return model.replace("xai/", "")
|
||||
|
||||
def get_models(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> List[str]:
|
||||
api_base = self.get_api_base(api_base)
|
||||
api_key = self.get_api_key(api_key)
|
||||
if api_base is None or api_key is None:
|
||||
raise ValueError(
|
||||
"XAI_API_BASE or XAI_API_KEY is not set. Please set the environment variable, to query XAI's `/models` endpoint."
|
||||
)
|
||||
response = litellm.module_level_client.get(
|
||||
url=f"{api_base}/v1/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError:
|
||||
raise Exception(
|
||||
f"Failed to fetch models from XAI. Status code: {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
models = response.json()["data"]
|
||||
|
||||
litellm_model_names = []
|
||||
for model in models:
|
||||
stripped_model_name = model["id"]
|
||||
litellm_model_name = "xai/" + stripped_model_name
|
||||
litellm_model_names.append(litellm_model_name)
|
||||
return litellm_model_names
|
||||
Reference in New Issue
Block a user