structure saas with tools
@@ -0,0 +1,124 @@
from typing import List, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse

from ..common_utils import TritonError


class TritonEmbeddingConfig(BaseEmbeddingConfig):
    """
    Transformations for triton /embeddings endpoint (This is a trtllm model)
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self, model: str) -> list:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Triton Embedding params
        """
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return {}

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        return {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(input)],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }
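    # Example: for input=["hello", "world"], the request body built above is a
    # Triton (KServe v2 inference protocol) payload:
    #   {"inputs": [{"name": "input_text", "shape": [2],
    #                "datatype": "BYTES", "data": ["hello", "world"]}]}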

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _embedding_output = []

        _outputs = raw_response_json["outputs"]
        for output in _outputs:
            _shape = output["shape"]
            _data = output["data"]
            _split_output_data = self.split_embedding_by_shape(_data, _shape)

            for idx, embedding in enumerate(_split_output_data):
                _embedding_output.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,
                    }
                )

        model_response.model = raw_response_json.get("model_name", "None")
        model_response.data = _embedding_output
        return model_response
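
    # For reference, a raw Triton response this method can parse (the output
    # "name" here is illustrative; only "shape", "data", and the top-level
    # "model_name" are actually read):
    #   {"model_name": "triton_model",
    #    "outputs": [{"name": "embedding", "shape": [2, 3],
    #                 "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]}]}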

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return TritonError(
            message=error_message, status_code=status_code, headers=headers
        )

    @staticmethod
    def split_embedding_by_shape(
        data: List[float], shape: List[int]
    ) -> List[List[float]]:
        if len(shape) != 2:
            raise ValueError("Shape must be of length 2.")
        embedding_size = shape[1]
        return [
            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
        ]
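
# split_embedding_by_shape slices Triton's flat, row-major "data" back into
# one vector per input, e.g.:
#   TritonEmbeddingConfig.split_embedding_by_shape(
#       data=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], shape=[2, 3]
#   )
#   # -> [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]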