structure saas with tools
Binary file not shown.
@@ -0,0 +1,15 @@
from typing import Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException


class TritonError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ) -> None:
        super().__init__(status_code=status_code, message=message, headers=headers)
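A minimal sketch of how this exception class is used, mirroring the transformation code later in this commit; the helper name is hypothetical and not part of the change:

import httpx

def _raise_for_triton_error(raw_response: httpx.Response) -> None:
    # Hypothetical helper: surface a failed Triton reply as a TritonError,
    # the same way the transform_response methods below do when the body
    # cannot be parsed as JSON.
    if raw_response.status_code >= 400:
        raise TritonError(
            status_code=raw_response.status_code,
            message=raw_response.text,
            headers=raw_response.headers,
        )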
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
"""
Triton Completion - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""
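As a rough usage sketch (not part of this commit), a completion call routed through this handler might look like the following; the model name, api_base, and the `triton/` provider prefix are assumptions:

import litellm

# Hypothetical server URL; the api_base must end in /generate or /infer so
# that TritonConfig (below) can pick the right request format.
response = litellm.completion(
    model="triton/llama-3",
    messages=[{"role": "user", "content": "Hello from Triton"}],
    api_base="http://localhost:8000/v2/models/llama-3/generate",
    max_tokens=64,
)
print(response.choices[0].message.content)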
@@ -0,0 +1,345 @@
"""
Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint.
"""

import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union

from httpx import Headers, Response

from litellm.constants import DEFAULT_MAX_TOKENS_FOR_TRITON
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import (
    BaseConfig,
    BaseLLMException,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    Choices,
    GenericStreamingChunk,
    Message,
    ModelResponse,
)

from ..common_utils import TritonError


class TritonConfig(BaseConfig):
    """
    Base class for Triton configurations.

    Handles routing between /infer and /generate triton completion llms
    """

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        return TritonError(
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        return {"Content-Type": "application/json"}

    def get_supported_openai_params(self, model: str) -> List:
        return ["max_tokens", "max_completion_tokens"]

    def map_openai_params(
        self,
        non_default_params: Dict,
        optional_params: Dict,
        model: str,
        drop_params: bool,
    ) -> Dict:
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params[param] = value
        return optional_params

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        if api_base is None:
            raise ValueError("api_base is required")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate" and stream:
            return api_base + "_stream"
        return api_base

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        return model_response

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        return {}

    def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
        if api_base.endswith("/generate"):
            return "generate"
        elif api_base.endswith("/infer"):
            return "infer"
        else:
            raise ValueError(f"Invalid Triton API base: {api_base}")

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return TritonResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
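# Illustrative only (not part of this commit): how TritonConfig above routes on
# the api_base suffix (see _get_triton_llm_type and get_complete_url). The URLs
# are hypothetical:
#   ".../generate" with stream=True   -> ".../generate_stream"
#   ".../generate" with stream=False  -> ".../generate"
#   ".../infer"                       -> ".../infer"
#   any other suffix                  -> ValueError("Invalid Triton API base: ...")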
class TritonGenerateConfig(TritonConfig):
    """
    Transformations for triton /generate endpoint (This is a trtllm model)
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        data_for_triton: Dict[str, Any] = {
            "text_input": prompt_factory(model=model, messages=messages),
            "parameters": {
                "max_tokens": int(
                    optional_params.get("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
                ),
            },
            "stream": bool(stream),
        }
        data_for_triton["parameters"].update(inference_params)
        return data_for_triton

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        model_response.choices = [
            Choices(index=0, message=Message(content=raw_response_json["text_output"]))
        ]

        return model_response
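# Illustrative only (not part of this commit): the /generate payload that
# TritonGenerateConfig.transform_request builds for a single user message when
# max_tokens=200 is passed. The variable name and values are hypothetical.
_EXAMPLE_GENERATE_REQUEST = {
    "text_input": "Hello from Triton",  # output of prompt_factory(...)
    "parameters": {"max_tokens": 200},
    "stream": False,
}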
class TritonInferConfig(TritonConfig):
    """
    Transformations for triton /infer endpoint (This is an infer model with a custom model on triton)
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        text_input = messages[0].get("content", "")
        data_for_triton = {
            "inputs": [
                {
                    "name": "text_input",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": [text_input],
                }
            ]
        }

        for k, v in optional_params.items():
            if not (k == "stream" or k == "max_retries"):
                datatype = "INT32" if isinstance(v, int) else "BYTES"
                datatype = "FP32" if isinstance(v, float) else datatype
                data_for_triton["inputs"].append(
                    {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
                )

        if "max_tokens" not in optional_params:
            data_for_triton["inputs"].append(
                {
                    "name": "max_tokens",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [20],
                }
            )
        return data_for_triton

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _triton_response_data = raw_response_json["outputs"][0]["data"]
        triton_response_data: Optional[str] = None
        if isinstance(_triton_response_data, list):
            triton_response_data = "".join(_triton_response_data)
        else:
            triton_response_data = _triton_response_data

        model_response.choices = [
            Choices(
                index=0,
                message=Message(content=triton_response_data),
            )
        ]

        return model_response
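# Illustrative only (not part of this commit): the KServe-style /infer payload
# built by TritonInferConfig.transform_request when no max_tokens is supplied,
# so the hard-coded default of 20 tokens is appended. The variable name and
# values are hypothetical.
_EXAMPLE_INFER_REQUEST = {
    "inputs": [
        {"name": "text_input", "shape": [1], "datatype": "BYTES", "data": ["Hello from Triton"]},
        {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]},
    ]
}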
class TritonResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None
            index = int(chunk.get("index", 0))

            # set values
            text = chunk.get("text_output", "")
            finish_reason = chunk.get("stop_reason", "")
            is_finished = chunk.get("is_finished", False)

            return GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
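For reference, a sketch of the streaming chunk shape that TritonResponseIterator.chunk_parser consumes; the field values are illustrative, not taken from this commit:

chunk = {
    "text_output": "Hello",
    "stop_reason": "",
    "is_finished": False,
    "index": 0,
}
# chunk_parser maps this to a GenericStreamingChunk with text="Hello",
# finish_reason="" and is_finished=False; usage and tool_use stay None.
parsed = TritonResponseIterator(
    streaming_response=iter([]), sync_stream=True
).chunk_parser(chunk)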
Binary file not shown.
@@ -0,0 +1,124 @@
from typing import List, Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException
from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse

from ..common_utils import TritonError


class TritonEmbeddingConfig(BaseEmbeddingConfig):
    """
    Transformations for triton /embeddings endpoint (This is a trtllm model)
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self, model: str) -> list:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Triton Embedding params
        """
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return {}

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        return {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(input)],
                    "datatype": "BYTES",
                    "data": input,
                }
            ]
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _embedding_output = []

        _outputs = raw_response_json["outputs"]
        for output in _outputs:
            _shape = output["shape"]
            _data = output["data"]
            _split_output_data = self.split_embedding_by_shape(_data, _shape)

            for idx, embedding in enumerate(_split_output_data):
                _embedding_output.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,
                    }
                )

        model_response.model = raw_response_json.get("model_name", "None")
        model_response.data = _embedding_output
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return TritonError(
            message=error_message, status_code=status_code, headers=headers
        )

    @staticmethod
    def split_embedding_by_shape(
        data: List[float], shape: List[int]
    ) -> List[List[float]]:
        if len(shape) != 2:
            raise ValueError("Shape must be of length 2.")
        embedding_size = shape[1]
        return [
            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
        ]
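As a rough usage sketch for the embedding path (not part of this commit); the model name, api_base, and `triton/` provider prefix are assumptions:

import litellm

# Hypothetical endpoint; transform_embedding_request packs the inputs into a
# single BYTES tensor named "input_text" with shape [len(input)].
response = litellm.embedding(
    model="triton/my-embedding-model",
    input=["first sentence", "second sentence"],
    api_base="http://localhost:8000/v2/models/my-embedding-model/infer",
)

# split_embedding_by_shape then reshapes the flat "data" array Triton returns:
# data=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], shape=[2, 3]
#   -> [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]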