structure saas with tools
@@ -0,0 +1,17 @@
from typing import List

from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams

from ...openai.transcriptions.whisper_transformation import (
    OpenAIWhisperAudioTranscriptionConfig,
)
from ..common_utils import FireworksAIMixin


class FireworksAIAudioTranscriptionConfig(
    FireworksAIMixin, OpenAIWhisperAudioTranscriptionConfig
):
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
        return ["language", "prompt", "response_format", "timestamp_granularities"]
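
For orientation, a minimal usage sketch of this config. The model id is illustrative (not taken from this commit); litellm.transcription is litellm's standard audio entry point, and the four params above are the only OpenAI optional params this config passes through:

import litellm

with open("sample.wav", "rb") as audio_file:
    response = litellm.transcription(
        model="fireworks_ai/whisper-v3",  # illustrative model id
        file=audio_file,
        language="en",            # one of the four supported optional params
        response_format="json",
    )
print(response.text)
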
@@ -0,0 +1,380 @@
import json
import uuid
from typing import Any, List, Literal, Optional, Tuple, Union, cast

import httpx

import litellm
from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
    get_response_headers,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionImageObject,
    ChatCompletionToolParam,
    OpenAIChatCompletionToolParam,
)
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    Choices,
    Function,
    Message,
    ModelResponse,
    ProviderSpecificModelInfo,
)

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import FireworksAIException


class FireworksAIConfig(OpenAIGPTConfig):
    """
    Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions

    The class `FireworksAIConfig` provides configuration for the Fireworks AI Chat Completions API interface. Below are the parameters:
    """

    tools: Optional[list] = None
    tool_choice: Optional[Union[str, dict]] = None
    max_tokens: Optional[int] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    frequency_penalty: Optional[int] = None
    presence_penalty: Optional[int] = None
    n: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    response_format: Optional[dict] = None
    user: Optional[str] = None
    logprobs: Optional[int] = None

    # Non OpenAI parameters - Fireworks AI only params
    prompt_truncate_length: Optional[int] = None
    context_length_exceeded_behavior: Optional[Literal["error", "truncate"]] = None

    def __init__(
        self,
        tools: Optional[list] = None,
        tool_choice: Optional[Union[str, dict]] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        frequency_penalty: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        response_format: Optional[dict] = None,
        user: Optional[str] = None,
        logprobs: Optional[int] = None,
        prompt_truncate_length: Optional[int] = None,
        context_length_exceeded_behavior: Optional[Literal["error", "truncate"]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str):
        return [
            "stream",
            "tools",
            "tool_choice",
            "max_completion_tokens",
            "max_tokens",
            "temperature",
            "top_p",
            "top_k",
            "frequency_penalty",
            "presence_penalty",
            "n",
            "stop",
            "response_format",
            "user",
            "logprobs",
            "prompt_truncate_length",
            "context_length_exceeded_behavior",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        is_tools_set = any(
            param == "tools" and value is not None
            for param, value in non_default_params.items()
        )

        for param, value in non_default_params.items():
            if param == "tool_choice":
                if value == "required":
                    # relevant issue: https://github.com/BerriAI/litellm/issues/4416
                    optional_params["tool_choice"] = "any"
                else:
                    # pass through the value of tool choice
                    optional_params["tool_choice"] = value
            elif param == "response_format":
                if (
                    is_tools_set
                ):  # fireworks ai doesn't support tools and response_format together
                    optional_params = self._add_response_format_to_tools(
                        optional_params=optional_params,
                        value=value,
                        is_response_format_supported=False,
                        enforce_tool_choice=False,  # tools and response_format are both set, don't enforce tool_choice
                    )
                elif "json_schema" in value:
                    optional_params["response_format"] = {
                        "type": "json_object",
                        "schema": value["json_schema"]["schema"],
                    }
                else:
                    optional_params["response_format"] = value
            elif param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value

        return optional_params

    def _add_transform_inline_image_block(
        self,
        content: ChatCompletionImageObject,
        model: str,
        disable_add_transform_inline_image_block: Optional[bool],
    ) -> ChatCompletionImageObject:
        """
        Add transform_inline to the image_url (allows non-vision models to parse documents/images/etc.)
        - ignore if model is a vision model
        - ignore if user has disabled this feature
        """
        if (
            "vision" in model or disable_add_transform_inline_image_block
        ):  # allow user to toggle this feature.
            return content
        if isinstance(content["image_url"], str):
            content["image_url"] = f"{content['image_url']}#transform=inline"
        elif isinstance(content["image_url"], dict):
            content["image_url"][
                "url"
            ] = f"{content['image_url']['url']}#transform=inline"
        return content

    def _transform_tools(
        self, tools: List[OpenAIChatCompletionToolParam]
    ) -> List[OpenAIChatCompletionToolParam]:
        for tool in tools:
            if tool.get("type") == "function":
                tool["function"].pop("strict", None)
        return tools

    def _transform_messages_helper(
        self, messages: List[AllMessageValues], model: str, litellm_params: dict
    ) -> List[AllMessageValues]:
        """
        Add 'transform=inline' to the url of the image_url
        """
        disable_add_transform_inline_image_block = cast(
            Optional[bool],
            litellm_params.get("disable_add_transform_inline_image_block")
            or litellm.disable_add_transform_inline_image_block,
        )
        for message in messages:
            if message["role"] == "user":
                _message_content = message.get("content")
                if _message_content is not None and isinstance(_message_content, list):
                    for content in _message_content:
                        if content["type"] == "image_url":
                            content = self._add_transform_inline_image_block(
                                content=content,
                                model=model,
                                disable_add_transform_inline_image_block=disable_add_transform_inline_image_block,
                            )
        return messages

    def get_provider_info(self, model: str) -> ProviderSpecificModelInfo:
        provider_specific_model_info = ProviderSpecificModelInfo(
            supports_function_calling=True,
            supports_prompt_caching=True,  # https://docs.fireworks.ai/guides/prompt-caching
            supports_pdf_input=True,  # via document inlining
            supports_vision=True,  # via document inlining
        )
        return provider_specific_model_info

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        if not model.startswith("accounts/"):
            model = f"accounts/fireworks/models/{model}"
        messages = self._transform_messages_helper(
            messages=messages, model=model, litellm_params=litellm_params
        )
        if "tools" in optional_params and optional_params["tools"] is not None:
            tools = self._transform_tools(tools=optional_params["tools"])
            optional_params["tools"] = tools
        return super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

    def _handle_message_content_with_tool_calls(
        self,
        message: Message,
        tool_calls: Optional[List[ChatCompletionToolParam]],
    ) -> Message:
        """
        Fireworks AI sends tool calls in the content field instead of tool_calls

        Relevant Issue: https://github.com/BerriAI/litellm/issues/7209#issuecomment-2813208780
        """
        if (
            tool_calls is not None
            and message.content is not None
            and message.tool_calls is None
        ):
            try:
                function = Function(**json.loads(message.content))
                if function.name != RESPONSE_FORMAT_TOOL_NAME and function.name in [
                    tool["function"]["name"] for tool in tool_calls
                ]:
                    tool_call = ChatCompletionMessageToolCall(
                        function=function, id=str(uuid.uuid4()), type="function"
                    )
                    message.tool_calls = [tool_call]

                    message.content = None
            except Exception:
                pass

        return message

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = raw_response.json()
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise FireworksAIException(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )

        raw_response_headers = dict(raw_response.headers)

        additional_headers = get_response_headers(raw_response_headers)

        response = ModelResponse(**completion_response)

        if response.model is not None:
            response.model = "fireworks_ai/" + response.model

        ## FIREWORKS AI sends tool calls in the content field instead of tool_calls
        for choice in response.choices:
            cast(
                Choices, choice
            ).message = self._handle_message_content_with_tool_calls(
                message=cast(Choices, choice).message,
                tool_calls=optional_params.get("tools", None),
            )

        response._hidden_params = {"additional_headers": additional_headers}

        return response

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = (
            api_base
            or get_secret_str("FIREWORKS_API_BASE")
            or "https://api.fireworks.ai/inference/v1"
        )  # type: ignore
        dynamic_api_key = api_key or (
            get_secret_str("FIREWORKS_API_KEY")
            or get_secret_str("FIREWORKS_AI_API_KEY")
            or get_secret_str("FIREWORKSAI_API_KEY")
            or get_secret_str("FIREWORKS_AI_TOKEN")
        )
        return api_base, dynamic_api_key

    def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
        api_base, api_key = self._get_openai_compatible_provider_info(
            api_base=api_base, api_key=api_key
        )
        if api_base is None or api_key is None:
            raise ValueError(
                "FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variable to query Fireworks AI's `/models` endpoint."
            )

        account_id = get_secret_str("FIREWORKS_ACCOUNT_ID")
        if account_id is None:
            raise ValueError(
                "FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable to query Fireworks AI's `/models` endpoint."
            )

        response = litellm.module_level_client.get(
            url=f"{api_base}/v1/accounts/{account_id}/models",
            headers={"Authorization": f"Bearer {api_key}"},
        )

        if response.status_code != 200:
            raise ValueError(
                f"Failed to fetch models from Fireworks AI. Status code: {response.status_code}, Response: {response.json()}"
            )

        models = response.json()["models"]

        return ["fireworks_ai/" + model["name"] for model in models]

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        return api_key or (
            get_secret_str("FIREWORKS_API_KEY")
            or get_secret_str("FIREWORKS_AI_API_KEY")
            or get_secret_str("FIREWORKSAI_API_KEY")
            or get_secret_str("FIREWORKS_AI_TOKEN")
        )
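
A short sketch of how the chat config behaves end to end. The model id is illustrative, not from this commit: transform_request expands it to accounts/fireworks/models/..., _transform_tools drops the strict key, and map_openai_params rewrites tool_choice="required" to "any":

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
            "strict": True,  # stripped by _transform_tools before the request is sent
        },
    }
]

response = litellm.completion(
    model="fireworks_ai/llama-v3p1-8b-instruct",  # illustrative model id
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    tools=tools,
    tool_choice="required",  # mapped to "any" (see map_openai_params above)
)
print(response.choices[0].message.tool_calls)
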
@@ -0,0 +1,52 @@
from typing import List, Optional, Union

from httpx import Headers

from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues

from ..base_llm.chat.transformation import BaseLLMException


class FireworksAIException(BaseLLMException):
    pass


class FireworksAIMixin:
    """
    Common Base Config functions across Fireworks AI Endpoints
    """

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return FireworksAIException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def _get_api_key(self, api_key: Optional[str]) -> Optional[str]:
        dynamic_api_key = api_key or (
            get_secret_str("FIREWORKS_API_KEY")
            or get_secret_str("FIREWORKS_AI_API_KEY")
            or get_secret_str("FIREWORKSAI_API_KEY")
            or get_secret_str("FIREWORKS_AI_TOKEN")
        )
        return dynamic_api_key

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        api_key = self._get_api_key(api_key)
        if api_key is None:
            raise ValueError("FIREWORKS_API_KEY is not set")

        return {"Authorization": "Bearer {}".format(api_key), **headers}
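
A sketch of the key-resolution behavior in FireworksAIMixin, assuming the file lands at litellm/llms/fireworks_ai/common_utils.py as the relative imports suggest. Any of the four environment variables satisfies _get_api_key; an explicit api_key argument wins over all of them:

import os

from litellm.llms.fireworks_ai.common_utils import FireworksAIMixin  # assumed path

os.environ["FIREWORKS_AI_API_KEY"] = "fw-..."  # FIREWORKS_API_KEY etc. also work

headers = FireworksAIMixin().validate_environment(
    headers={},
    model="llama-v3p1-8b-instruct",  # illustrative model id
    messages=[{"role": "user", "content": "hi"}],
    optional_params={},
    litellm_params={},
)
# headers == {"Authorization": "Bearer fw-..."}
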
@@ -0,0 +1,61 @@
from typing import List, Union

from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage

from ...base_llm.completion.transformation import BaseTextCompletionConfig
from ...openai.completion.utils import _transform_prompt
from ..common_utils import FireworksAIMixin


class FireworksAITextCompletionConfig(FireworksAIMixin, BaseTextCompletionConfig):
    def get_supported_openai_params(self, model: str) -> list:
        """
        See how LiteLLM supports Provider-specific parameters - https://docs.litellm.ai/docs/completion/provider_specific_params#proxy-usage
        """
        return [
            "max_tokens",
            "logprobs",
            "echo",
            "temperature",
            "top_p",
            "top_k",
            "frequency_penalty",
            "presence_penalty",
            "n",
            "stop",
            "response_format",
            "stream",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        supported_params = self.get_supported_openai_params(model)
        for k, v in non_default_params.items():
            if k in supported_params:
                optional_params[k] = v
        return optional_params

    def transform_text_completion_request(
        self,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        prompt = _transform_prompt(messages=messages)

        if not model.startswith("accounts/"):
            model = f"accounts/fireworks/models/{model}"

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }
        return data
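
A usage sketch for the text-completion path, with an illustrative model id. transform_text_completion_request adds the accounts/fireworks/models/ prefix and merges the mapped params into the request body:

import litellm

response = litellm.text_completion(
    model="fireworks_ai/llama-v3p1-8b-instruct",  # illustrative model id
    prompt="Once upon a time",
    max_tokens=32,    # passed through by map_openai_params
    temperature=0.7,
)
print(response.choices[0].text)
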
@@ -0,0 +1,84 @@
"""
For calculating cost of fireworks ai serverless inference models.
"""

from typing import Tuple

from litellm.constants import (
    FIREWORKS_AI_16_B,
    FIREWORKS_AI_56_B_MOE,
    FIREWORKS_AI_80_B,
    FIREWORKS_AI_176_B_MOE,
)
from litellm.types.utils import Usage
from litellm.utils import get_model_info


# Extract the number of billion parameters from the model name
# and map it to a Fireworks AI pricing category
def get_base_model_for_pricing(model_name: str) -> str:
    """
    Helper function for calculating Fireworks AI pricing.

    Returns:
    - str: model pricing category if mapped, else the default pricing category
    """
    import re

    model_name = model_name.lower()

    # Check for MoE models in the form <number>x<number>b
    moe_match = re.search(r"(\d+)x(\d+)b", model_name)
    if moe_match:
        total_billion = int(moe_match.group(1)) * int(moe_match.group(2))
        if total_billion <= FIREWORKS_AI_56_B_MOE:
            return "fireworks-ai-moe-up-to-56b"
        elif total_billion <= FIREWORKS_AI_176_B_MOE:
            return "fireworks-ai-56b-to-176b"

    # Check for standard models in the form <number>b
    re_params_match = re.search(r"(\d+)b", model_name)
    if re_params_match is not None:
        params_match = str(re_params_match.group(1))
        params_billion = float(params_match)

        # Determine the category based on the number of parameters
        if params_billion <= FIREWORKS_AI_16_B:
            return "fireworks-ai-up-to-16b"
        elif params_billion <= FIREWORKS_AI_80_B:
            return "fireworks-ai-16b-80b"

    # If no matches, return the default pricing category
    return "fireworks-ai-default"


def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
    - model: str, the model name without provider prefix
    - usage: LiteLLM Usage block, containing the prompt and completion token counts

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    ## check if model mapped, else use default pricing
    try:
        model_info = get_model_info(model=model, custom_llm_provider="fireworks_ai")
    except Exception:
        base_model = get_base_model_for_pricing(model_name=model)

        ## GET MODEL INFO
        model_info = get_model_info(
            model=base_model, custom_llm_provider="fireworks_ai"
        )

    ## CALCULATE INPUT COST
    prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]

    ## CALCULATE OUTPUT COST
    completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]

    return prompt_cost, completion_cost
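
A worked example of the bucketing above, assuming the imported constants carry their face values (16, 56, 80, 176):

get_base_model_for_pricing("accounts/fireworks/models/mixtral-8x7b-instruct")
# MoE match "8x7b": 8 * 7 = 56 <= 56 -> "fireworks-ai-moe-up-to-56b"

get_base_model_for_pricing("accounts/fireworks/models/llama-v3p1-70b-instruct")
# standard match "70b": 16 < 70 <= 80 -> "fireworks-ai-16b-80b"

get_base_model_for_pricing("my-custom-model")
# no parameter count in the name -> "fireworks-ai-default"
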
@@ -0,0 +1,44 @@
"""
This is OpenAI compatible - no transformation is applied
"""

import litellm


class FireworksAIEmbeddingConfig:
    def get_supported_openai_params(self, model: str):
        """
        `dimensions`: Only supported in nomic-ai/nomic-embed-text-v1.5 and later models.

        https://docs.fireworks.ai/api-reference/creates-an-embedding-vector-representing-the-input-text
        """
        if "nomic-ai" in model:
            return ["dimensions"]
        return []

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict, model: str
    ):
        """
        No transformation is applied - fireworks ai is openai compatible
        """
        supported_openai_params = self.get_supported_openai_params(model)
        for param, value in non_default_params.items():
            if param in supported_openai_params:
                optional_params[param] = value
        return optional_params

    def is_fireworks_embedding_model(self, model: str):
        """
        Helper to check if a model is a fireworks embedding model.

        Fireworks embeddings do not support passing /accounts/fireworks in the model name, so we need to know if it's a known embedding model.
        """
        if (
            model in litellm.fireworks_ai_embedding_models
            or f"fireworks_ai/{model}" in litellm.fireworks_ai_embedding_models
        ):
            return True

        return False
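
A usage sketch for the embedding config. The model id contains "nomic-ai", so dimensions is passed through; whether it appears in litellm.fireworks_ai_embedding_models is an assumption worth checking:

import litellm

response = litellm.embedding(
    model="fireworks_ai/nomic-ai/nomic-embed-text-v1.5",
    input=["hello world"],
    dimensions=256,  # supported because "nomic-ai" is in the model name
)
print(len(response.data[0]["embedding"]))
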