structure saas with tools
@@ -0,0 +1,566 @@
"""
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions`
"""

from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import httpx
from pydantic import BaseModel

from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
    _handle_invalid_parallel_tool_calls,
    _should_convert_tool_call_to_json_mode,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
    strip_name_from_messages,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.anthropic import AllAnthropicToolsValues
from litellm.types.llms.databricks import (
    AllDatabricksContentValues,
    DatabricksChoice,
    DatabricksFunction,
    DatabricksResponse,
    DatabricksTool,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionRedactedThinkingBlock,
    ChatCompletionThinkingBlock,
    ChatCompletionToolChoiceFunctionParam,
    ChatCompletionToolChoiceObjectParam,
)
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    Choices,
    Message,
    ModelResponse,
    ModelResponseStream,
    ProviderField,
    Usage,
)

from ...anthropic.chat.transformation import AnthropicConfig
from ...openai_like.chat.transformation import OpenAILikeChatConfig
from ..common_utils import DatabricksBase, DatabricksException

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any

class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
    """
    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
    """

    max_tokens: Optional[int] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    stop: Optional[Union[List[str], str]] = None
    n: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="chat_completions",
            custom_endpoint=False,
            headers=headers,
        )
        # Ensure Content-Type header is set
        headers["Content-Type"] = "application/json"
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        api_base = self._get_api_base(api_base)
        complete_url = f"{api_base}/chat/completions"
        return complete_url

    def get_supported_openai_params(self, model: Optional[str] = None) -> list:
        return [
            "stream",
            "stop",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "n",
            "response_format",
            "tools",
            "tool_choice",
            "reasoning_effort",
            "thinking",
        ]
    def convert_anthropic_tool_to_databricks_tool(
        self, tool: Optional[AllAnthropicToolsValues]
    ) -> Optional[DatabricksTool]:
        if tool is None:
            return None

        return DatabricksTool(
            type="function",
            function=DatabricksFunction(
                name=tool["name"],
                parameters=cast(dict, tool.get("input_schema") or {}),
            ),
        )

    def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
        # if not claude, send as is
        if "claude" not in model:
            return tools

        # if claude, convert to anthropic tool and then to databricks tool
        anthropic_tools = self._map_tools(tools=tools)
        databricks_tools = [
            cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool))
            for tool in anthropic_tools
        ]
        return databricks_tools

    def map_response_format_to_databricks_tool(
        self,
        model: str,
        value: Optional[dict],
        optional_params: dict,
        is_thinking_enabled: bool,
    ) -> Optional[DatabricksTool]:
        if value is None:
            return None

        tool = self.map_response_format_to_anthropic_tool(
            value, optional_params, is_thinking_enabled
        )

        databricks_tool = self.convert_anthropic_tool_to_databricks_tool(tool)
        return databricks_tool

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        replace_max_completion_tokens_with_max_tokens: bool = True,
    ) -> dict:
        is_thinking_enabled = self.is_thinking_enabled(non_default_params)
        mapped_params = super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )
        if "tools" in mapped_params:
            mapped_params["tools"] = self._map_openai_to_dbrx_tool(
                model=model, tools=mapped_params["tools"]
            )
        if (
            "max_completion_tokens" in non_default_params
            and replace_max_completion_tokens_with_max_tokens
        ):
            # most openai-compatible providers support 'max_tokens', not 'max_completion_tokens'
            mapped_params["max_tokens"] = non_default_params["max_completion_tokens"]
            mapped_params.pop("max_completion_tokens", None)

        if "response_format" in non_default_params and "claude" in model:
            _tool = self.map_response_format_to_databricks_tool(
                model,
                non_default_params["response_format"],
                mapped_params,
                is_thinking_enabled,
            )

            if _tool is not None:
                self._add_tools_to_optional_params(
                    optional_params=optional_params, tools=[_tool]
                )
                optional_params["json_mode"] = True
                if not is_thinking_enabled:
                    _tool_choice = ChatCompletionToolChoiceObjectParam(
                        type="function",
                        function=ChatCompletionToolChoiceFunctionParam(
                            name=RESPONSE_FORMAT_TOOL_NAME
                        ),
                    )
                    optional_params["tool_choice"] = _tool_choice
                # unsupported for claude models - if json_schema -> convert to tool call
                optional_params.pop("response_format", None)

        if "reasoning_effort" in non_default_params and "claude" in model:
            optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
                non_default_params.get("reasoning_effort")
            )
        optional_params.pop("reasoning_effort", None)

        ## handle thinking tokens
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=mapped_params
        )

        return mapped_params
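    # Illustrative mapping for the method above (a sketch; the exact output also
    # depends on the OpenAI-like base config). For a claude model on Databricks,
    #   non_default_params = {"max_completion_tokens": 256,
    #                         "response_format": {"type": "json_object"}}
    # is expected to come back with 'max_tokens' set to 256, 'response_format'
    # replaced by a RESPONSE_FORMAT_TOOL_NAME tool call with 'json_mode' enabled,
    # and 'tool_choice' pinned to that tool when thinking is disabled.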
    def _should_fake_stream(self, optional_params: dict) -> bool:
        """
        Databricks doesn't support 'response_format' while streaming
        """
        if optional_params.get("response_format") is not None:
            return True

        return False

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """
        Databricks does not support:
        - content in list format.
        - 'name' in user message.
        """
        new_messages = []
        for idx, message in enumerate(messages):
            if isinstance(message, BaseModel):
                _message = message.model_dump(exclude_none=True)
            else:
                _message = message
            new_messages.append(_message)
        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
        new_messages = strip_name_from_messages(new_messages)
        return super()._transform_messages(messages=new_messages, model=model)

    @staticmethod
    def extract_content_str(
        content: Optional[AllDatabricksContentValues],
    ) -> Optional[str]:
        if content is None:
            return None
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            content_str = ""
            for item in content:
                if item["type"] == "text":
                    content_str += item["text"]
            return content_str
        else:
            raise Exception(f"Unsupported content type: {type(content)}")

    @staticmethod
    def extract_reasoning_content(
        content: Optional[AllDatabricksContentValues],
    ) -> Tuple[
        Optional[str],
        Optional[
            List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]]
        ],
    ]:
        """
        Extract and return the reasoning content and thinking blocks
        """
        if content is None:
            return None, None
        thinking_blocks: Optional[
            List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]]
        ] = None
        reasoning_content: Optional[str] = None
        if isinstance(content, list):
            for item in content:
                if item["type"] == "reasoning":
                    for summary_item in item["summary"]:
                        if reasoning_content is None:
                            reasoning_content = ""
                        reasoning_content += summary_item["text"]
                        thinking_block = ChatCompletionThinkingBlock(
                            type="thinking",
                            thinking=summary_item["text"],
                            signature=summary_item["signature"],
                        )
                        if thinking_blocks is None:
                            thinking_blocks = []
                        thinking_blocks.append(thinking_block)
        return reasoning_content, thinking_blocks
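    # Assumed payload shape for the parser above (inferred from the key accesses,
    # not from Databricks documentation):
    #   content = [
    #       {"type": "reasoning", "summary": [{"text": "...", "signature": "..."}]},
    #       {"type": "text", "text": "final answer"},
    #   ]
    # The summary texts are concatenated into reasoning_content, and each summary
    # item is wrapped in a ChatCompletionThinkingBlock.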
    def _transform_choices(
        self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
    ) -> List[Choices]:
        transformed_choices = []

        for choice in choices:
            ## HANDLE JSON MODE - anthropic returns a single function call
            tool_calls = choice["message"].get("tool_calls", None)
            if tool_calls is not None:
                _openai_tool_calls = []
                for _tc in tool_calls:
                    _openai_tc = ChatCompletionMessageToolCall(**_tc)  # type: ignore
                    _openai_tool_calls.append(_openai_tc)
                fixed_tool_calls = _handle_invalid_parallel_tool_calls(
                    _openai_tool_calls
                )

                if fixed_tool_calls is not None:
                    tool_calls = fixed_tool_calls

            translated_message: Optional[Message] = None
            finish_reason: Optional[str] = None
            if tool_calls and _should_convert_tool_call_to_json_mode(
                tool_calls=tool_calls,
                convert_tool_call_to_json_mode=json_mode,
            ):
                # to support response_format on claude models
                json_mode_content_str: Optional[str] = (
                    str(tool_calls[0]["function"].get("arguments", "")) or None
                )
                if json_mode_content_str is not None:
                    translated_message = Message(content=json_mode_content_str)
                    finish_reason = "stop"

            if translated_message is None:
                ## get the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["message"]["content"]
                )

                ## get the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["message"].get("content")
                )

                translated_message = Message(
                    role="assistant",
                    content=content_str,
                    reasoning_content=reasoning_content,
                    thinking_blocks=thinking_blocks,
                    tool_calls=choice["message"].get("tool_calls"),
                )

            if finish_reason is None:
                finish_reason = choice["finish_reason"]

            translated_choice = Choices(
                finish_reason=finish_reason,
                index=choice["index"],
                message=translated_message,
                logprobs=None,
                enhancements=None,
            )

            transformed_choices.append(translated_choice)

        return transformed_choices
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = DatabricksResponse(**raw_response.json())  # type: ignore
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise DatabricksException(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )

        model_response.model = completion_response["model"]
        model_response.id = completion_response["id"]
        model_response.created = completion_response["created"]
        setattr(model_response, "usage", Usage(**completion_response["usage"]))

        model_response.choices = self._transform_choices(  # type: ignore
            choices=completion_response["choices"],
            json_mode=json_mode,
        )

        return model_response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return DatabricksChatResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class DatabricksChatResponseIterator(BaseModelResponseIterator):
    def __init__(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        super().__init__(streaming_response, sync_stream)

        self.json_mode = json_mode
        self._last_function_name = None  # Track the last seen function name

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        try:
            translated_choices = []
            for choice in chunk["choices"]:
                tool_calls = choice["delta"].get("tool_calls")
                if tool_calls and self.json_mode:
                    # 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME
                    # 2. If no function name, just args -> check last function name (saved via state variable)
                    # 3. Convert args to json
                    # 4. Convert json to message
                    # 5. Set content to message.content
                    # 6. Set tool_calls to None
                    from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
                    from litellm.llms.base_llm.base_utils import (
                        _convert_tool_response_to_message,
                    )

                    # Check if this chunk has a function name
                    function_name = tool_calls[0].get("function", {}).get("name")
                    if function_name is not None:
                        self._last_function_name = function_name

                    # If we have a saved function name that matches RESPONSE_FORMAT_TOOL_NAME
                    # or this chunk has the matching function name
                    if (
                        self._last_function_name == RESPONSE_FORMAT_TOOL_NAME
                        or function_name == RESPONSE_FORMAT_TOOL_NAME
                    ):
                        # Convert tool calls to message format
                        message = _convert_tool_response_to_message(tool_calls)
                        if message is not None:
                            if message.content == "{}":  # empty json
                                message.content = ""
                            choice["delta"]["content"] = message.content
                            choice["delta"]["tool_calls"] = None
                elif tool_calls:
                    for _tc in tool_calls:
                        if _tc.get("function", {}).get("arguments") == "{}":
                            _tc["function"]["arguments"] = ""  # avoid invalid json

                # extract the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["delta"].get("content")
                )

                # extract the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["delta"].get("content")
                )

                choice["delta"]["content"] = content_str
                choice["delta"]["reasoning_content"] = reasoning_content
                choice["delta"]["thinking_blocks"] = thinking_blocks
                translated_choices.append(choice)
            return ModelResponseStream(
                id=chunk["id"],
                object="chat.completion.chunk",
                created=chunk["created"],
                model=chunk["model"],
                choices=translated_choices,
            )
        except KeyError as e:
            raise DatabricksException(
                message=f"KeyError: {e}, Got unexpected response from Databricks: {chunk}",
                status_code=400,
            )
        except Exception as e:
            raise e
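A minimal usage sketch for the config above, assuming LiteLLM's public
`litellm.completion` entrypoint routes `databricks/` models through
DatabricksConfig; the model name, api_base, and api_key are placeholders:

    import litellm

    response = litellm.completion(
        model="databricks/databricks-claude-3-7-sonnet",  # illustrative model name
        messages=[{"role": "user", "content": "Summarize DBRX in one line."}],
        response_format={"type": "json_object"},  # converted to a tool call for claude models
        api_base="https://adb-xxxx.azuredatabricks.net/serving-endpoints",  # placeholder
        api_key="dapi...",  # placeholder
    )
    print(response.choices[0].message.content)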
@@ -0,0 +1,108 @@
from typing import Literal, Optional, Tuple

from litellm.llms.base_llm.chat.transformation import BaseLLMException


class DatabricksException(BaseLLMException):
    pass


class DatabricksBase:
    def _get_api_base(self, api_base: Optional[str]) -> str:
        if api_base is None:
            try:
                from databricks.sdk import WorkspaceClient

                databricks_client = WorkspaceClient()

                api_base = (
                    api_base or f"{databricks_client.config.host}/serving-endpoints"
                )

                return api_base
            except ImportError:
                raise DatabricksException(
                    status_code=400,
                    message=(
                        "Either set the DATABRICKS_API_BASE and DATABRICKS_API_KEY environment variables, "
                        "or install the databricks-sdk Python library."
                    ),
                )
        return api_base

    def _get_databricks_credentials(
        self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict]
    ) -> Tuple[str, dict]:
        headers = headers or {"Content-Type": "application/json"}
        try:
            from databricks.sdk import WorkspaceClient

            databricks_client = WorkspaceClient()

            api_base = api_base or f"{databricks_client.config.host}/serving-endpoints"

            if api_key is None:
                databricks_auth_headers: dict[
                    str, str
                ] = databricks_client.config.authenticate()
                headers = {**databricks_auth_headers, **headers}

            return api_base, headers
        except ImportError:
            raise DatabricksException(
                status_code=400,
                message=(
                    "If the Databricks base URL and API key are not set, the databricks-sdk "
                    "Python library must be installed. Please install the databricks-sdk, set "
                    "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, "
                    "or provide the base URL and API key as arguments."
                ),
            )

    def databricks_validate_environment(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        endpoint_type: Literal["chat_completions", "embeddings"],
        custom_endpoint: Optional[bool],
        headers: Optional[dict],
    ) -> Tuple[str, dict]:
        if api_key is None and not headers:  # handle empty headers
            if custom_endpoint is True:
                raise DatabricksException(
                    status_code=400,
                    message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params",
                )
            else:
                api_base, headers = self._get_databricks_credentials(
                    api_base=api_base, api_key=api_key, headers=headers
                )

        if api_base is None:
            if custom_endpoint:
                raise DatabricksException(
                    status_code=400,
                    message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_BASE) or via params",
                )
            else:
                api_base, headers = self._get_databricks_credentials(
                    api_base=api_base, api_key=api_key, headers=headers
                )

        if headers is None:
            headers = {"Content-Type": "application/json"}
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        if endpoint_type == "chat_completions" and custom_endpoint is not True:
            api_base = "{}/chat/completions".format(api_base)
        elif endpoint_type == "embeddings" and custom_endpoint is not True:
            api_base = "{}/embeddings".format(api_base)
        return api_base, headers
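A short sketch of how the helper above resolves credentials, assuming explicit
arguments (environment-variable lookup happens upstream in LiteLLM); values are
placeholders:

    base = DatabricksBase()

    # With explicit credentials, the databricks-sdk is never imported.
    api_base, headers = base.databricks_validate_environment(
        api_key="dapi...",  # placeholder
        api_base="https://adb-xxxx.azuredatabricks.net/serving-endpoints",  # placeholder
        endpoint_type="chat_completions",
        custom_endpoint=False,
        headers=None,
    )
    # api_base now ends with '/chat/completions'; headers carry the Bearer token
    # and Content-Type. With no key and no headers, WorkspaceClient() supplies both.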
@@ -0,0 +1,66 @@
"""
Helper util for handling databricks-specific cost calculation
- e.g.: handling 'dbrx-instruct-*'
"""

from typing import Tuple

from litellm.types.utils import Usage
from litellm.utils import get_model_info


def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without the provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    base_model = model
    if model.startswith("databricks/dbrx-instruct") or model.startswith(
        "dbrx-instruct"
    ):
        base_model = "databricks-dbrx-instruct"
    elif model.startswith("databricks/meta-llama-3.1-70b-instruct") or model.startswith(
        "meta-llama-3.1-70b-instruct"
    ):
        base_model = "databricks-meta-llama-3-1-70b-instruct"
    elif model.startswith(
        "databricks/meta-llama-3.1-405b-instruct"
    ) or model.startswith("meta-llama-3.1-405b-instruct"):
        base_model = "databricks-meta-llama-3-1-405b-instruct"
    elif model.startswith("databricks/mixtral-8x7b-instruct-v0.1") or model.startswith(
        "mixtral-8x7b-instruct-v0.1"
    ):
        base_model = "databricks-mixtral-8x7b-instruct"
    elif model.startswith("databricks/bge-large-en") or model.startswith(
        "bge-large-en"
    ):
        base_model = "databricks-bge-large-en"
    elif model.startswith("databricks/gte-large-en") or model.startswith(
        "gte-large-en"
    ):
        base_model = "databricks-gte-large-en"
    elif model.startswith("databricks/llama-2-70b-chat") or model.startswith(
        "llama-2-70b-chat"
    ):
        base_model = "databricks-llama-2-70b-chat"

    ## GET MODEL INFO
    model_info = get_model_info(model=base_model, custom_llm_provider="databricks")

    ## CALCULATE INPUT COST
    prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]

    ## CALCULATE OUTPUT COST
    completion_cost = usage["completion_tokens"] * model_info["output_cost_per_token"]

    return prompt_cost, completion_cost
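A worked example of the arithmetic above, with made-up per-token rates (real
rates come from get_model_info):

    prompt_tokens, completion_tokens = 1200, 300
    input_rate, output_rate = 0.75e-6, 2.25e-6  # hypothetical USD per token
    prompt_cost = prompt_tokens * input_rate  # 0.0009
    completion_cost = completion_tokens * output_rate  # 0.000675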
@@ -0,0 +1,49 @@
"""
Calling logic for Databricks embeddings
"""

from typing import Optional

from litellm.utils import EmbeddingResponse

from ...openai_like.embedding.handler import OpenAILikeEmbeddingHandler
from ..common_utils import DatabricksBase


class DatabricksEmbeddingHandler(OpenAILikeEmbeddingHandler, DatabricksBase):
    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        logging_obj,
        api_key: Optional[str],
        api_base: Optional[str],
        optional_params: dict,
        model_response: Optional[EmbeddingResponse] = None,
        client=None,
        aembedding=None,
        custom_endpoint: Optional[bool] = None,
        headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="embeddings",
            custom_endpoint=custom_endpoint,
            headers=headers,
        )
        return super().embedding(
            model=model,
            input=input,
            timeout=timeout,
            logging_obj=logging_obj,
            api_key=api_key,
            api_base=api_base,
            optional_params=optional_params,
            model_response=model_response,
            client=client,
            aembedding=aembedding,
            custom_endpoint=True,
            headers=headers,
        )
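A hedged usage sketch for the handler above, via LiteLLM's public
`litellm.embedding` entrypoint; the model name and credentials are placeholders:

    import litellm

    response = litellm.embedding(
        model="databricks/databricks-bge-large-en",  # illustrative model name
        input=["Represent this sentence for searching relevant passages: what is DBRX?"],
        api_base="https://adb-xxxx.azuredatabricks.net/serving-endpoints",  # placeholder
        api_key="dapi...",  # placeholder
    )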
@@ -0,0 +1,48 @@
"""
Translates from OpenAI's `/v1/embeddings` to Databricks' `/embeddings`
"""

import types
from typing import Optional


class DatabricksEmbeddingConfig:
    """
    Reference: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/foundation-models/api-reference#--embedding-task
    """

    # An optional instruction to pass to the embedding model. The BGE authors recommend
    # 'Represent this sentence for searching relevant passages:' for retrieval queries.
    instruction: Optional[str] = None

    def __init__(self, instruction: Optional[str] = None) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        # no optional openai embedding params are supported
        return []

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        return optional_params
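Because __init__ writes onto the class (not the instance), the configuration is
shared process-wide; a brief sketch of what get_config returns once an
instruction is set:

    DatabricksEmbeddingConfig(
        instruction="Represent this sentence for searching relevant passages:"
    )
    assert DatabricksEmbeddingConfig.get_config() == {
        "instruction": "Represent this sentence for searching relevant passages:"
    }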
@@ -0,0 +1,165 @@
import json
from typing import Optional

import litellm
from litellm import verbose_logger
from litellm.types.llms.openai import (
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
    ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk, Usage


class ModelResponseIterator:
    def __init__(self, streaming_response, sync_stream: bool):
        self.streaming_response = streaming_response

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            processed_chunk = litellm.ModelResponseStream(**chunk)

            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None

            if processed_chunk.choices[0].delta.content is not None:  # type: ignore
                text = processed_chunk.choices[0].delta.content  # type: ignore

            if (
                processed_chunk.choices[0].delta.tool_calls is not None  # type: ignore
                and len(processed_chunk.choices[0].delta.tool_calls) > 0  # type: ignore
                and processed_chunk.choices[0].delta.tool_calls[0].function is not None  # type: ignore
                and processed_chunk.choices[0].delta.tool_calls[0].function.arguments  # type: ignore
                is not None
            ):
                tool_use = ChatCompletionToolCallChunk(
                    id=processed_chunk.choices[0].delta.tool_calls[0].id,  # type: ignore
                    type="function",
                    function=ChatCompletionToolCallFunctionChunk(
                        name=processed_chunk.choices[0]
                        .delta.tool_calls[0]  # type: ignore
                        .function.name,
                        arguments=processed_chunk.choices[0]
                        .delta.tool_calls[0]  # type: ignore
                        .function.arguments,
                    ),
                    index=processed_chunk.choices[0].delta.tool_calls[0].index,
                )

            if processed_chunk.choices[0].finish_reason is not None:
                is_finished = True
                finish_reason = processed_chunk.choices[0].finish_reason

            usage_chunk: Optional[Usage] = getattr(processed_chunk, "usage", None)
            if usage_chunk is not None:
                usage = ChatCompletionUsageBlock(
                    prompt_tokens=usage_chunk.prompt_tokens,
                    completion_tokens=usage_chunk.completion_tokens,
                    total_tokens=usage_chunk.total_tokens,
                )

            return GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=0,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    # Sync iterator
    def __iter__(self):
        self.response_iterator = self.streaming_response
        return self

    def __next__(self):
        if not hasattr(self, "response_iterator"):
            self.response_iterator = self.streaming_response
        try:
            chunk = self.response_iterator.__next__()
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
            chunk = chunk.strip()
            if len(chunk) > 0:
                json_chunk = json.loads(chunk)
                return self.chunk_parser(chunk=json_chunk)
            else:
                return GenericStreamingChunk(
                    text="",
                    is_finished=False,
                    finish_reason="",
                    usage=None,
                    index=0,
                    tool_use=None,
                )
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            verbose_logger.debug(
                f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
            )
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")
        except Exception as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
            chunk = chunk.strip()
            if chunk == "[DONE]":
                raise StopAsyncIteration
            if len(chunk) > 0:
                json_chunk = json.loads(chunk)
                return self.chunk_parser(chunk=json_chunk)
            else:
                return GenericStreamingChunk(
                    text="",
                    is_finished=False,
                    finish_reason="",
                    usage=None,
                    index=0,
                    tool_use=None,
                )
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            verbose_logger.debug(
                f"Error parsing chunk: {e},\nReceived chunk: {chunk}. Defaulting to empty chunk here."
            )
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )
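A small sketch of how this iterator consumes a raw SSE stream; it assumes each
yielded item is a 'data: {...}' line (the format _strip_sse_data_from_chunk
expects) and that GenericStreamingChunk behaves like a dict:

    raw_sse_lines = iter(
        [
            'data: {"id": "chatcmpl-1", "created": 1700000000, "model": "dbrx-instruct",'
            ' "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}',
            "",  # keep-alive -> parsed into an empty GenericStreamingChunk
        ]
    )

    iterator = ModelResponseIterator(streaming_response=raw_sse_lines, sync_stream=True)
    for parsed in iterator:
        print(parsed["text"])  # "Hello", then ""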