structure saas with tools
@@ -0,0 +1,307 @@
import asyncio
import json
import time
from typing import Callable, List, Union

import litellm
from litellm.constants import REPLICATE_POLLING_DELAY_SECONDS
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import CustomStreamWrapper, ModelResponse

from ..common_utils import ReplicateError
from .transformation import ReplicateConfig

replicate_config = ReplicateConfig()


# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(
    prediction_url, api_token, print_verbose, headers: dict, http_client: HTTPHandler
):
    previous_output = ""
    output_string = ""

    status = ""
    while True and (status not in ["succeeded", "failed", "canceled"]):
        time.sleep(
            REPLICATE_POLLING_DELAY_SECONDS
        )  # prevent being rate limited by replicate
        print_verbose(f"replicate: polling endpoint: {prediction_url}")
        response = http_client.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            status = response_data["status"]
            if "output" in response_data:
                try:
                    output_string = "".join(response_data["output"])
                except Exception:
                    raise ReplicateError(
                        status_code=422,
                        message="Unable to parse response. Got={}".format(
                            response_data["output"]
                        ),
                        headers=response.headers,
                    )
                new_output = output_string[len(previous_output) :]
                print_verbose(f"New chunk: {new_output}")
                yield {"output": new_output, "status": status}
                previous_output = output_string
            status = response_data["status"]
            if status == "failed":
                replicate_error = response_data.get("error", "")
                raise ReplicateError(
                    status_code=400,
                    message=f"Error: {replicate_error}",
                    headers=response.headers,
                )
        else:
            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
            print_verbose(
                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
            )


# Async function to handle prediction response (streaming)
async def async_handle_prediction_response_streaming(
    prediction_url,
    api_token,
    print_verbose,
    headers: dict,
    http_client: AsyncHTTPHandler,
):
    previous_output = ""
    output_string = ""

    status = ""
    while True and (status not in ["succeeded", "failed", "canceled"]):
        await asyncio.sleep(
            REPLICATE_POLLING_DELAY_SECONDS
        )  # prevent being rate limited by replicate
        print_verbose(f"replicate: polling endpoint: {prediction_url}")
        response = await http_client.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            status = response_data["status"]
            if "output" in response_data:
                try:
                    output_string = "".join(response_data["output"])
                except Exception:
                    raise ReplicateError(
                        status_code=422,
                        message="Unable to parse response. Got={}".format(
                            response_data["output"]
                        ),
                        headers=response.headers,
                    )
                new_output = output_string[len(previous_output) :]
                print_verbose(f"New chunk: {new_output}")
                yield {"output": new_output, "status": status}
                previous_output = output_string
            status = response_data["status"]
            if status == "failed":
                replicate_error = response_data.get("error", "")
                raise ReplicateError(
                    status_code=400,
                    message=f"Error: {replicate_error}",
                    headers=response.headers,
                )
        else:
            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
            print_verbose(
                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
            )


# Main function for prediction completion
def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    optional_params: dict,
    litellm_params: dict,
    logging_obj,
    api_key,
    encoding,
    custom_prompt_dict={},
    logger_fn=None,
    acompletion=None,
    headers={},
) -> Union[ModelResponse, CustomStreamWrapper]:
    headers = replicate_config.validate_environment(
        api_key=api_key,
        headers=headers,
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
    )
    # Start a prediction and get the prediction URL
    version_id = replicate_config.model_to_version_id(model)
    input_data = replicate_config.transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )

    if acompletion is not None and acompletion is True:
        return async_completion(
            model_response=model_response,
            model=model,
            encoding=encoding,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            version_id=version_id,
            input_data=input_data,
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging_obj,
            print_verbose=print_verbose,
            headers=headers,
        )  # type: ignore
    ## COMPLETION CALL
    model_response.created = int(
        time.time()
    )  # for pricing this must remain right before calling api

    prediction_url = replicate_config.get_complete_url(
        api_base=api_base,
        api_key=api_key,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
    )

    ## COMPLETION CALL
    httpx_client = _get_httpx_client(
        params={"timeout": 600.0},
    )
    response = httpx_client.post(
        url=prediction_url,
        headers=headers,
        data=json.dumps(input_data),
    )

    prediction_url = replicate_config.get_prediction_url(response)

    # Handle the prediction response (streaming or non-streaming)
    if "stream" in optional_params and optional_params["stream"] is True:
        print_verbose("streaming request")
        _response = handle_prediction_response_streaming(
            prediction_url,
            api_key,
            print_verbose,
            headers=headers,
            http_client=httpx_client,
        )
        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
    else:
        for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
            time.sleep(
                litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry
            )  # wait to allow response to be generated by replicate - else partial output is generated with status=="processing"
            response = httpx_client.get(url=prediction_url, headers=headers)
            if (
                response.status_code == 200
                and response.json().get("status") == "processing"
            ):
                continue
            return litellm.ReplicateConfig().transform_response(
                model=model,
                raw_response=response,
                model_response=model_response,
                logging_obj=logging_obj,
                api_key=api_key,
                request_data=input_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
            )

    raise ReplicateError(
        status_code=500,
        message="No response received from Replicate API after max retries",
        headers=None,
    )


async def async_completion(
    model_response: ModelResponse,
    model: str,
    messages: List[AllMessageValues],
    encoding,
    optional_params: dict,
    litellm_params: dict,
    version_id,
    input_data,
    api_key,
    api_base,
    logging_obj,
    print_verbose,
    headers: dict,
) -> Union[ModelResponse, CustomStreamWrapper]:
    prediction_url = replicate_config.get_complete_url(
        api_base=api_base,
        api_key=api_key,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
    )
    async_handler = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.REPLICATE,
        params={"timeout": 600.0},
    )
    response = await async_handler.post(
        url=prediction_url, headers=headers, data=json.dumps(input_data)
    )
    prediction_url = replicate_config.get_prediction_url(response)

    if "stream" in optional_params and optional_params["stream"] is True:
        _response = async_handle_prediction_response_streaming(
            prediction_url,
            api_key,
            print_verbose,
            headers=headers,
            http_client=async_handler,
        )
        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore

    for retry in range(litellm.DEFAULT_REPLICATE_POLLING_RETRIES):
        await asyncio.sleep(
            litellm.DEFAULT_REPLICATE_POLLING_DELAY_SECONDS + 2 * retry
        )  # wait to allow response to be generated by replicate - else partial output is generated with status=="processing"
        response = await async_handler.get(url=prediction_url, headers=headers)
        if (
            response.status_code == 200
            and response.json().get("status") == "processing"
        ):
            continue
        return litellm.ReplicateConfig().transform_response(
            model=model,
            raw_response=response,
            model_response=model_response,
            logging_obj=logging_obj,
            api_key=api_key,
            request_data=input_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
        )
    # Add a fallback return if no response is received after max retries
    raise ReplicateError(
        status_code=500,
        message="No response received from Replicate API after max retries",
        headers=None,
    )
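
A minimal usage sketch for the handler above, assuming `REPLICATE_API_KEY` is set and using an illustrative model slug. The non-streaming path polls the prediction URL until the status leaves "processing"; with `stream=True`, chunks come from `handle_prediction_response_streaming` via `CustomStreamWrapper`.

import os

import litellm

os.environ["REPLICATE_API_KEY"] = "r8_..."  # placeholder token

# Non-streaming: completion() polls Replicate until a terminal status is reached.
resp = litellm.completion(
    model="replicate/meta/llama-2-70b-chat",  # illustrative model slug
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=64,
)
print(resp.choices[0].message.content)

# Streaming: the generator above yields the newly appended output on each poll.
for chunk in litellm.completion(
    model="replicate/meta/llama-2-70b-chat",
    messages=[{"role": "user", "content": "Count to five."}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")
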
@@ -0,0 +1,323 @@
from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

import litellm
from litellm.constants import REPLICATE_MODEL_NAME_WITH_ID_LENGTH
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    convert_content_list_to_str,
)
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import token_counter

from ..common_utils import ReplicateError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any


class ReplicateConfig(BaseConfig):
    """
    Reference: https://replicate.com/meta/llama-2-70b-chat/api
    - `prompt` (string): The prompt to send to the model.

    - `system_prompt` (string): The system prompt to send to the model. This is prepended to the prompt and helps guide system behavior. Default value: `You are a helpful assistant`.

    - `max_new_tokens` (integer): Maximum number of tokens to generate. Typically, a word is made up of 2-3 tokens. Default value: `128`.

    - `min_new_tokens` (integer): Minimum number of tokens to generate. To disable, set to `-1`. A word is usually 2-3 tokens. Default value: `-1`.

    - `temperature` (number): Adjusts the randomness of outputs. Values greater than 1 increase randomness, 0 is deterministic, and 0.75 is a reasonable starting value. Default value: `0.75`.

    - `top_p` (number): During text decoding, it samples from the top `p` percentage of most likely tokens. Reduce this to ignore less probable tokens. Default value: `0.9`.

    - `top_k` (integer): During text decoding, samples from the top `k` most likely tokens. Reduce this to ignore less probable tokens. Default value: `50`.

    - `stop_sequences` (string): A comma-separated list of sequences to stop generation at. For example, inputting '<end>,<stop>' will cease generation at the first occurrence of either '<end>' or '<stop>'.

    - `seed` (integer): This is the seed for the random generator. Leave it blank to randomize the seed.

    - `debug` (boolean): If set to `True`, it provides debugging output in logs.

    Please note that Replicate's mapping of these parameters can be inconsistent across different models, indicating that not all of these parameters may be available for use with all models.
    """

    system_prompt: Optional[str] = None
    max_new_tokens: Optional[int] = None
    min_new_tokens: Optional[int] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    stop_sequences: Optional[str] = None
    seed: Optional[int] = None
    debug: Optional[bool] = None

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        stop_sequences: Optional[str] = None,
        seed: Optional[int] = None,
        debug: Optional[bool] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "stop",
            "seed",
            "tools",
            "tool_choice",
            "functions",
            "function_call",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for param, value in non_default_params.items():
            if param == "stream":
                optional_params["stream"] = value
            if param == "max_tokens":
                if "vicuna" in model or "flan" in model:
                    optional_params["max_length"] = value
                elif "meta/codellama-13b" in model:
                    optional_params["max_tokens"] = value
                else:
                    optional_params["max_new_tokens"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value

        return optional_params

    # Function to extract version ID from model string
    def model_to_version_id(self, model: str) -> str:
        if ":" in model:
            split_model = model.split(":")
            return split_model[1]
        return model

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return ReplicateError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        version_id = self.model_to_version_id(model)
        base_url = api_base
        if "deployments" in version_id:
            version_id = version_id.replace("deployments/", "")
            base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
        else:  # assume it's a model
            base_url = f"https://api.replicate.com/v1/models/{version_id}"

        base_url = f"{base_url}/predictions"
        return base_url

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        ## Load Config
        config = litellm.ReplicateConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > replicate_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        system_prompt = None
        if optional_params is not None and "supports_system_prompt" in optional_params:
            supports_sys_prompt = optional_params.pop("supports_system_prompt")
        else:
            supports_sys_prompt = False

        if supports_sys_prompt:
            for i in range(len(messages)):
                if messages[i]["role"] == "system":
                    first_sys_message = messages.pop(i)
                    system_prompt = convert_content_list_to_str(first_sys_message)
                    break

        if model in litellm.custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = litellm.custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", {}),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                bos_token=model_prompt_details.get("bos_token", ""),
                eos_token=model_prompt_details.get("eos_token", ""),
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        if prompt is None or not isinstance(prompt, str):
            raise ReplicateError(
                status_code=400,
                message="LiteLLM Error - prompt is not a string - {}".format(prompt),
                headers={},
            )

        # If system prompt is supported, and a system prompt is provided, use it
        if system_prompt is not None:
            input_data = {
                "prompt": prompt,
                "system_prompt": system_prompt,
                **optional_params,
            }
        # Otherwise, use the prompt as is
        else:
            input_data = {"prompt": prompt, **optional_params}

        version_id = self.model_to_version_id(model)
        request_data: dict = {"input": input_data}
        if ":" in version_id and len(version_id) > REPLICATE_MODEL_NAME_WITH_ID_LENGTH:
            model_parts = version_id.split(":")
            if (
                len(model_parts) > 1
                and len(model_parts[1]) == REPLICATE_MODEL_NAME_WITH_ID_LENGTH
            ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
                request_data["version"] = model_parts[1]

        return request_data

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )
        raw_response_json = raw_response.json()
        if raw_response_json.get("status") != "succeeded":
            raise ReplicateError(
                status_code=422,
                message="LiteLLM Error - prediction not succeeded - {}".format(
                    raw_response_json
                ),
                headers=raw_response.headers,
            )
        outputs = raw_response_json.get("output", [])
        response_str = "".join(outputs)
        if len(response_str) == 0:  # edge case, where result from replicate is empty
            response_str = " "

        ## Building RESPONSE OBJECT
        if len(response_str) >= 1:
            model_response.choices[0].message.content = response_str  # type: ignore

        # Calculate usage
        prompt_tokens = token_counter(model=model, messages=messages)
        completion_tokens = token_counter(
            model=model,
            text=response_str,
            count_response_tokens=True,
        )
        model_response.model = "replicate/" + model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)

        return model_response

    def get_prediction_url(self, response: httpx.Response) -> str:
        """
        response json: {
            ...,
            "urls":{"cancel":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4/cancel","get":"https://api.replicate.com/v1/predictions/gqsmqmp1pdrj00cknr08dgmvb4","stream":"https://stream-b.svc.rno2.c.replicate.net/v1/streams/eot4gbydowuin4snhncydwxt57dfwgsc3w3snycx5nid7oef7jga"}
        }
        """
        response_json = response.json()
        prediction_url = response_json.get("urls", {}).get("get")
        if prediction_url is None:
            raise ReplicateError(
                status_code=400,
                message="LiteLLM Error - prediction url is None - {}".format(
                    response_json
                ),
                headers=response.headers,
            )
        return prediction_url

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        headers = {
            "Authorization": f"Token {api_key}",
            "Content-Type": "application/json",
        }
        return headers
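
A rough trace of the transformation above, assuming a plain llama-style model name (so `max_tokens` maps to `max_new_tokens` and `stop` to `stop_sequences`); the model slug and parameter values are illustrative:

cfg = ReplicateConfig()
optional_params = cfg.map_openai_params(
    non_default_params={"max_tokens": 64, "temperature": 0.2, "stop": "###"},
    optional_params={},
    model="meta/llama-2-70b-chat",  # illustrative model slug
    drop_params=False,
)
# -> {"max_new_tokens": 64, "temperature": 0.2, "stop_sequences": "###"}

# transform_request() then renders the messages into a prompt and wraps everything as
# {"input": {"prompt": "...", "max_new_tokens": 64, ...}}, adding a top-level "version"
# key only when the model string carries a 64-character version id after the colon.
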
@@ -0,0 +1,15 @@
from typing import Optional, Union

import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException


class ReplicateError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]],
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)
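
To close out, a minimal async sketch against the code above (model slug and prompt are illustrative); a ReplicateError raised inside the handler is remapped by litellm's exception handling, so a broad catch is used here:

import asyncio

import litellm


async def main() -> None:
    try:
        resp = await litellm.acompletion(
            model="replicate/meta/llama-2-70b-chat",  # illustrative model slug
            messages=[{"role": "user", "content": "Give me one fun fact."}],
            max_tokens=64,
        )
        print(resp.choices[0].message.content)
    except Exception as e:  # ReplicateError surfaces via litellm's mapped exception types
        print(f"Replicate call failed: {e}")


asyncio.run(main())
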