structure saas with tools
@@ -0,0 +1,68 @@
from typing import Any, Optional, Union

from pydantic import BaseModel

from litellm.types.utils import HiddenParams


def _add_headers_to_response(response: Any, headers: dict) -> Any:
    """
    Helper function to add headers to a response's hidden params
    """
    if response is None or not isinstance(response, BaseModel):
        return response

    hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
        response, "_hidden_params", {}
    )

    if hidden_params is None:
        hidden_params = {}
    elif isinstance(hidden_params, HiddenParams):
        hidden_params = hidden_params.model_dump()

    hidden_params.setdefault("additional_headers", {})
    hidden_params["additional_headers"].update(headers)

    setattr(response, "_hidden_params", hidden_params)
    return response


def add_retry_headers_to_response(
    response: Any,
    attempted_retries: int,
    max_retries: Optional[int] = None,
) -> Any:
    """
    Add retry headers to the response
    """
    retry_headers = {
        "x-litellm-attempted-retries": attempted_retries,
    }
    if max_retries is not None:
        retry_headers["x-litellm-max-retries"] = max_retries

    return _add_headers_to_response(response, retry_headers)


def add_fallback_headers_to_response(
    response: Any,
    attempted_fallbacks: int,
) -> Any:
    """
    Add fallback headers to the response

    Args:
        response: The response to add the headers to
        attempted_fallbacks: The number of fallbacks attempted

    Returns:
        The response with the headers added

    Note: It's intentional that we don't add max_fallbacks in response headers.
    We want to avoid bloat in the response headers for performance.
    """
    fallback_headers = {
        "x-litellm-attempted-fallbacks": attempted_fallbacks,
    }
    return _add_headers_to_response(response, fallback_headers)
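# Usage example (illustrative sketch, not part of this commit; `resp` is a hypothetical
# pydantic response object returned by a router call):
# resp = add_retry_headers_to_response(response=resp, attempted_retries=2, max_retries=3)
# resp._hidden_params["additional_headers"]
# # -> {"x-litellm-attempted-retries": 2, "x-litellm-max-retries": 3}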
@@ -0,0 +1,63 @@
import io
import json
from typing import Optional, Tuple, Union


class InMemoryFile(io.BytesIO):
    def __init__(self, content: bytes, name: str):
        super().__init__(content)
        self.name = name


def replace_model_in_jsonl(
    file_content: Union[bytes, Tuple[str, bytes, str]], new_model_name: str
) -> Optional[InMemoryFile]:
    try:
        # If file_content is a file-like object, read the bytes
        if hasattr(file_content, "read"):
            file_content_bytes = file_content.read()  # type: ignore
        elif isinstance(file_content, tuple):
            file_content_bytes = file_content[1]
        else:
            file_content_bytes = file_content

        # Decode the bytes to a string and split into lines
        if isinstance(file_content_bytes, bytes):
            file_content_str = file_content_bytes.decode("utf-8")
        else:
            file_content_str = file_content_bytes
        lines = file_content_str.splitlines()
        modified_lines = []
        for line in lines:
            # Parse each line as a JSON object
            json_object = json.loads(line.strip())

            # Replace the model name if it exists
            if "body" in json_object:
                json_object["body"]["model"] = new_model_name

            # Convert the modified JSON object back to a string
            modified_lines.append(json.dumps(json_object))

        # Reassemble the modified lines and return as bytes
        modified_file_content = "\n".join(modified_lines).encode("utf-8")
        return InMemoryFile(modified_file_content, name="modified_file.jsonl")  # type: ignore

    except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
        return None


def _get_router_metadata_variable_name(function_name) -> str:
    """
    Helper to return what the "metadata" field should be called in the request data

    For all /thread or /assistant endpoints we need to call this "litellm_metadata"

    For ALL other endpoints we call this "metadata"
    """
    ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"])
    if function_name in ROUTER_METHODS_USING_LITELLM_METADATA:
        return "litellm_metadata"
    else:
        return "metadata"
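# Usage example (illustrative sketch, not part of this commit):
# jsonl_bytes = b'{"custom_id": "req-1", "body": {"model": "gpt-4o", "messages": []}}'
# new_file = replace_model_in_jsonl(file_content=jsonl_bytes, new_model_name="azure/gpt-4o")
# # -> InMemoryFile named "modified_file.jsonl" with every line's body.model rewritten,
# #    or None if the content is not valid JSONL / UTF-8.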
@@ -0,0 +1,37 @@
import asyncio
from typing import TYPE_CHECKING, Any

from litellm.utils import calculate_max_parallel_requests

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


class InitalizeCachedClient:
    @staticmethod
    def set_max_parallel_requests_client(
        litellm_router_instance: LitellmRouter, model: dict
    ):
        # Derive a concurrency limit from the deployment's rpm/tpm settings and cache an
        # asyncio.Semaphore for it in the router's local cache.
        litellm_params = model.get("litellm_params", {})
        model_id = model["model_info"]["id"]
        rpm = litellm_params.get("rpm", None)
        tpm = litellm_params.get("tpm", None)
        max_parallel_requests = litellm_params.get("max_parallel_requests", None)
        calculated_max_parallel_requests = calculate_max_parallel_requests(
            rpm=rpm,
            max_parallel_requests=max_parallel_requests,
            tpm=tpm,
            default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
        )
        if calculated_max_parallel_requests:
            semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
            cache_key = f"{model_id}_max_parallel_requests_client"
            litellm_router_instance.cache.set_cache(
                key=cache_key,
                value=semaphore,
                local_only=True,
            )
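# Usage example (illustrative sketch, not part of this commit; `router` is an initialized litellm Router):
# InitalizeCachedClient.set_max_parallel_requests_client(
#     litellm_router_instance=router,
#     model={"litellm_params": {"rpm": 100}, "model_info": {"id": "my-deployment-id"}},
# )
# # If a limit can be derived from rpm/tpm, an asyncio.Semaphore is cached under
# # "my-deployment-id_max_parallel_requests_client" in the router's local cache.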
@@ -0,0 +1,37 @@
"""
Utils for handling clientside credentials

Supported clientside credentials:
- api_key
- api_base
- base_url

If given, generate a unique model_id for the deployment.

Ensures cooldowns are applied correctly.
"""

clientside_credential_keys = ["api_key", "api_base", "base_url"]


def is_clientside_credential(request_kwargs: dict) -> bool:
    """
    Check if the credential is a clientside credential.
    """
    return any(key in request_kwargs for key in clientside_credential_keys)


def get_dynamic_litellm_params(litellm_params: dict, request_kwargs: dict) -> dict:
    """
    Copy any clientside credentials from request_kwargs into litellm_params,
    which is then used for generating a unique model_id.

    Returns:
    - litellm_params: dict
    """
    # update litellm_params with clientside credentials
    for key in clientside_credential_keys:
        if key in request_kwargs:
            litellm_params[key] = request_kwargs[key]
    return litellm_params
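# Usage example (illustrative sketch, not part of this commit):
# request_kwargs = {"api_key": "sk-user-supplied", "messages": [{"role": "user", "content": "hi"}]}
# if is_clientside_credential(request_kwargs):
#     litellm_params = get_dynamic_litellm_params(
#         litellm_params={"model": "openai/gpt-4o"}, request_kwargs=request_kwargs
#     )
#     # litellm_params now carries the caller-supplied api_key, so the deployment can be
#     # given its own model_id and cooled down independently.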
@@ -0,0 +1,14 @@
import hashlib
import json

from litellm.types.router import CredentialLiteLLMParams


def get_litellm_params_sensitive_credential_hash(litellm_params: dict) -> str:
    """
    Hash of the credential params, used for mapping the file id to the right model
    """
    sensitive_params = CredentialLiteLLMParams(**litellm_params)
    return hashlib.sha256(
        json.dumps(sensitive_params.model_dump()).encode()
    ).hexdigest()
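# Usage example (illustrative sketch, not part of this commit; assumes these keys are
# accepted by CredentialLiteLLMParams):
# credential_hash = get_litellm_params_sensitive_credential_hash(
#     {"api_key": "sk-...", "api_base": "https://example-endpoint.openai.azure.com"}
# )
# # -> deterministic sha256 hex digest of the credential fields, usable as a stable
# #    key for mapping uploaded file ids to the right deployment.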
@@ -0,0 +1,170 @@
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""

import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict, Union

from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


class CooldownCacheValue(TypedDict):
    exception_received: str
    status_code: str
    timestamp: float
    cooldown_time: float


class CooldownCache:
    def __init__(self, cache: DualCache, default_cooldown_time: float):
        self.cache = cache
        self.default_cooldown_time = default_cooldown_time
        self.in_memory_cache = InMemoryCache()

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, CooldownCacheValue]:
        try:
            current_time = time.time()
            cooldown_key = f"deployment:{model_id}:cooldown"

            # Store the cooldown information for the deployment separately
            cooldown_data = CooldownCacheValue(
                exception_received=str(original_exception),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        try:
            _cooldown_time = cooldown_time or self.default_cooldown_time
            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    @staticmethod
    def get_cooldown_cache_key(model_id: str) -> str:
        return f"deployment:{model_id}:cooldown"

    async def async_get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        # Retrieve the values for the keys using mget
        ## more likely to be None if no models are rate-limited, so just check Redis every 1s
        ## each Redis call adds ~100ms latency

        ## check in-memory cache first
        results = await self.cache.async_batch_get_cache(
            keys=keys, parent_otel_span=parent_otel_span
        )
        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []

        if results is None:
            return active_cooldowns

        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        active_cooldowns = []
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_min_cooldown(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> float:
        """Return the minimum cooldown time required for a group of model ids."""

        # Generate the keys for the deployments
        keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]

        # Retrieve the values for the keys using mget
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        min_cooldown_time: Optional[float] = None
        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                if min_cooldown_time is None:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]
                elif cooldown_cache_value["cooldown_time"] < min_cooldown_time:
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]

        return min_cooldown_time or self.default_cooldown_time


# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
# active_cooldowns = cooldown_cache.get_active_cooldowns()
@@ -0,0 +1,98 @@
"""
Callbacks triggered on cooling down deployments
"""

import copy
from typing import TYPE_CHECKING, Any, Optional, Union

import litellm
from litellm._logging import verbose_logger

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
    from litellm.integrations.prometheus import PrometheusLogger
else:
    LitellmRouter = Any
    PrometheusLogger = Any


async def router_cooldown_event_callback(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
    exception_status: Union[str, int],
    cooldown_time: float,
):
    """
    Callback triggered when a deployment is put into cooldown by litellm

    - Updates deployment state on Prometheus
    - Increments cooldown metric for deployment on Prometheus
    """
    verbose_logger.debug("In router_cooldown_event_callback - updating prometheus")
    _deployment = litellm_router_instance.get_deployment(model_id=deployment_id)
    if _deployment is None:
        verbose_logger.warning(
            f"in router_cooldown_event_callback but _deployment is None for deployment_id={deployment_id}. Doing nothing"
        )
        return
    _litellm_params = _deployment["litellm_params"]
    temp_litellm_params = copy.deepcopy(_litellm_params)
    temp_litellm_params = dict(temp_litellm_params)
    _model_name = _deployment.get("model_name", None) or ""
    _api_base = (
        litellm.get_api_base(model=_model_name, optional_params=temp_litellm_params)
        or ""
    )
    model_info = _deployment["model_info"]
    model_id = model_info.id

    litellm_model_name = temp_litellm_params.get("model") or ""
    llm_provider = ""
    try:
        _, llm_provider, _, _ = litellm.get_llm_provider(
            model=litellm_model_name,
            custom_llm_provider=temp_litellm_params.get("custom_llm_provider"),
        )
    except Exception:
        pass

    # get the prometheus logger from the in-memory loggers
    prometheusLogger: Optional[
        PrometheusLogger
    ] = _get_prometheus_logger_from_callbacks()

    if prometheusLogger is not None:
        prometheusLogger.set_deployment_complete_outage(
            litellm_model_name=_model_name,
            model_id=model_id,
            api_base=_api_base,
            api_provider=llm_provider,
        )

        prometheusLogger.increment_deployment_cooled_down(
            litellm_model_name=_model_name,
            model_id=model_id,
            api_base=_api_base,
            api_provider=llm_provider,
            exception_status=str(exception_status),
        )

    return


def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
    """
    Checks if prometheus is an initialized callback; if yes, returns it
    """
    from litellm.integrations.prometheus import PrometheusLogger

    for _callback in litellm._async_success_callback:
        if isinstance(_callback, PrometheusLogger):
            return _callback
    for global_callback in litellm.callbacks:
        if isinstance(global_callback, PrometheusLogger):
            return global_callback

    return None
@@ -0,0 +1,438 @@
"""
Router cooldown handlers
- _set_cooldown_deployments: puts a deployment in the cooldown list
- get_cooldown_deployments: returns the list of deployments in the cooldown list
- async_get_cooldown_deployments: ASYNC: returns the list of deployments in the cooldown list
"""

import asyncio
from typing import TYPE_CHECKING, Any, List, Optional, Union

import litellm
from litellm._logging import verbose_router_logger
from litellm.constants import (
    DEFAULT_COOLDOWN_TIME_SECONDS,
    DEFAULT_FAILURE_THRESHOLD_PERCENT,
    SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD,
)
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback

from .router_callbacks.track_deployment_metrics import (
    get_deployment_failures_for_current_minute,
    get_deployment_successes_for_current_minute,
)

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router as _Router

    LitellmRouter = _Router
    Span = Union[_Span, Any]
else:
    LitellmRouter = Any
    Span = Any


def _is_cooldown_required(
    litellm_router_instance: LitellmRouter,
    model_id: str,
    exception_status: Union[str, int],
    exception_str: Optional[str] = None,
) -> bool:
    """
    A function to determine if a cooldown is required based on the exception status.

    Parameters:
        model_id (str): The id of the model in the model list
        exception_status (Union[str, int]): The status of the exception.

    Returns:
        bool: True if a cooldown is required, False otherwise.
    """
    try:
        ignored_strings = ["APIConnectionError"]
        if (
            exception_str is not None
        ):  # don't cooldown on litellm api connection errors
            for ignored_string in ignored_strings:
                if ignored_string in exception_str:
                    return False

        if isinstance(exception_status, str):
            exception_status = int(exception_status)

        if exception_status >= 400 and exception_status < 500:
            if exception_status == 429:
                # Cool down 429 Rate Limit Errors
                return True

            elif exception_status == 401:
                # Cool down 401 Auth Errors
                return True

            elif exception_status == 408:
                return True

            elif exception_status == 404:
                return True

            else:
                # Do NOT cool down all other 4XX Errors
                return False

        else:
            # should cool down for all other errors
            return True

    except Exception:
        # Catch all - if any exceptions occur, default to cooling down
        return True


def _should_run_cooldown_logic(
    litellm_router_instance: LitellmRouter,
    deployment: Optional[str],
    exception_status: Union[str, int],
    original_exception: Any,
) -> bool:
    """
    Helper that decides if cooldown logic should be run.
    Returns False if cooldown logic should not be run.

    Does not run cooldown logic when:
    - router.disable_cooldowns is True
    - deployment is None
    - _is_cooldown_required() returns False
    - deployment is in litellm_router_instance.provider_default_deployment_ids
    - exception_status is not one that should be immediately retried (e.g. 401)
    """
    if (
        deployment is None
        or litellm_router_instance.get_model_group(id=deployment) is None
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: deployment id is none or model group can't be found."
        )
        return False

    if litellm_router_instance.disable_cooldowns:
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: disable_cooldowns is True"
        )
        return False

    if deployment is None:
        verbose_router_logger.debug("Should Not Run Cooldown Logic: deployment is None")
        return False

    if not _is_cooldown_required(
        litellm_router_instance=litellm_router_instance,
        model_id=deployment,
        exception_status=exception_status,
        exception_str=str(original_exception),
    ):
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: _is_cooldown_required returned False"
        )
        return False

    if deployment in litellm_router_instance.provider_default_deployment_ids:
        verbose_router_logger.debug(
            "Should Not Run Cooldown Logic: deployment is in provider_default_deployment_ids"
        )
        return False

    return True


def _should_cooldown_deployment(
    litellm_router_instance: LitellmRouter,
    deployment: str,
    exception_status: Union[str, int],
    original_exception: Any,
) -> bool:
    """
    Helper that decides if a deployment should be put in cooldown

    Returns True if the deployment should be put in cooldown
    Returns False if the deployment should not be put in cooldown

    Deployment is put in cooldown when:
    - v2 logic (Current):
        cooldown if:
        - got a 429 error from LLM API
        - if %fails/%(successes + fails) > ALLOWED_FAILURE_RATE_PER_MINUTE
        - got 401 Auth error, 404 NotFoundError - checked by litellm._should_retry()

    - v1 logic (Legacy): if allowed fails or an allowed fail policy is set, cooldown if num fails in this minute > allowed fails
    """
    ## BASE CASE - single deployment
    model_group = litellm_router_instance.get_model_group(id=deployment)
    is_single_deployment_model_group = False
    if model_group is not None and len(model_group) == 1:
        is_single_deployment_model_group = True
    if (
        litellm_router_instance.allowed_fails_policy is None
        and _is_allowed_fails_set_on_router(
            litellm_router_instance=litellm_router_instance
        )
        is False
    ):
        num_successes_this_minute = get_deployment_successes_for_current_minute(
            litellm_router_instance=litellm_router_instance, deployment_id=deployment
        )
        num_fails_this_minute = get_deployment_failures_for_current_minute(
            litellm_router_instance=litellm_router_instance, deployment_id=deployment
        )

        total_requests_this_minute = num_successes_this_minute + num_fails_this_minute
        percent_fails = 0.0
        if total_requests_this_minute > 0:
            percent_fails = num_fails_this_minute / (
                num_successes_this_minute + num_fails_this_minute
            )
        verbose_router_logger.debug(
            "percent fails for deployment = %s, percent fails = %s, num successes = %s, num fails = %s",
            deployment,
            percent_fails,
            num_successes_this_minute,
            num_fails_this_minute,
        )

        exception_status_int = cast_exception_status_to_int(exception_status)
        if exception_status_int == 429 and not is_single_deployment_model_group:
            return True
        elif (
            percent_fails == 1.0
            and total_requests_this_minute
            >= SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD
        ):
            # Cooldown if all requests failed and we have reasonable traffic
            return True
        elif (
            percent_fails > DEFAULT_FAILURE_THRESHOLD_PERCENT
            and not is_single_deployment_model_group  # by default we should avoid cooldowns on single deployment model groups
        ):
            return True

        elif (
            litellm._should_retry(
                status_code=cast_exception_status_to_int(exception_status)
            )
            is False
        ):
            return True

        return False
    else:
        return should_cooldown_based_on_allowed_fails_policy(
            litellm_router_instance=litellm_router_instance,
            deployment=deployment,
            original_exception=original_exception,
        )

    return False


def _set_cooldown_deployments(
    litellm_router_instance: LitellmRouter,
    original_exception: Any,
    exception_status: Union[str, int],
    deployment: Optional[str] = None,
    time_to_cooldown: Optional[float] = None,
) -> bool:
    """
    Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute

    or

    the exception is not one that should be immediately retried (e.g. 401)

    Returns:
    - True if the deployment should be put in cooldown
    - False if the deployment should not be put in cooldown
    """
    verbose_router_logger.debug("checks 'should_run_cooldown_logic'")

    if (
        _should_run_cooldown_logic(
            litellm_router_instance, deployment, exception_status, original_exception
        )
        is False
        or deployment is None
    ):
        verbose_router_logger.debug("should_run_cooldown_logic returned False")
        return False

    exception_status_int = cast_exception_status_to_int(exception_status)

    verbose_router_logger.debug(f"Attempting to add {deployment} to cooldown list")
    cooldown_time = litellm_router_instance.cooldown_time or 1
    if time_to_cooldown is not None:
        cooldown_time = time_to_cooldown

    if _should_cooldown_deployment(
        litellm_router_instance, deployment, exception_status, original_exception
    ):
        litellm_router_instance.cooldown_cache.add_deployment_to_cooldown(
            model_id=deployment,
            original_exception=original_exception,
            exception_status=exception_status_int,
            cooldown_time=cooldown_time,
        )

        # Trigger cooldown callback handler
        asyncio.create_task(
            router_cooldown_event_callback(
                litellm_router_instance=litellm_router_instance,
                deployment_id=deployment,
                exception_status=exception_status,
                cooldown_time=cooldown_time,
            )
        )
        return True
    return False


async def _async_get_cooldown_deployments(
    litellm_router_instance: LitellmRouter,
    parent_otel_span: Optional[Span],
) -> List[str]:
    """
    Async implementation of '_get_cooldown_deployments'
    """
    model_ids = litellm_router_instance.get_model_ids()
    cooldown_models = (
        await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
            model_ids=model_ids,
            parent_otel_span=parent_otel_span,
        )
    )

    cached_value_deployment_ids = []
    if (
        cooldown_models is not None
        and isinstance(cooldown_models, list)
        and len(cooldown_models) > 0
        and isinstance(cooldown_models[0], tuple)
    ):
        cached_value_deployment_ids = [cv[0] for cv in cooldown_models]

    verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
    return cached_value_deployment_ids


async def _async_get_cooldown_deployments_with_debug_info(
    litellm_router_instance: LitellmRouter,
    parent_otel_span: Optional[Span],
) -> List[tuple]:
    """
    Async implementation of '_get_cooldown_deployments'
    """
    model_ids = litellm_router_instance.get_model_ids()
    cooldown_models = (
        await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
            model_ids=model_ids, parent_otel_span=parent_otel_span
        )
    )

    verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
    return cooldown_models


def _get_cooldown_deployments(
    litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span]
) -> List[str]:
    """
    Get the list of models being cooled down for this minute
    """
    # get the current cooldown list for that minute

    # ----------------------
    # Return cooldown models
    # ----------------------
    model_ids = litellm_router_instance.get_model_ids()

    cooldown_models = litellm_router_instance.cooldown_cache.get_active_cooldowns(
        model_ids=model_ids, parent_otel_span=parent_otel_span
    )

    cached_value_deployment_ids = []
    if (
        cooldown_models is not None
        and isinstance(cooldown_models, list)
        and len(cooldown_models) > 0
        and isinstance(cooldown_models[0], tuple)
    ):
        cached_value_deployment_ids = [cv[0] for cv in cooldown_models]

    return cached_value_deployment_ids


def should_cooldown_based_on_allowed_fails_policy(
    litellm_router_instance: LitellmRouter,
    deployment: str,
    original_exception: Any,
) -> bool:
    """
    Check if fails are within the allowed limit and update the number of fails.

    Returns:
    - True if fails exceed the allowed limit (should cooldown)
    - False if fails are within the allowed limit (should not cooldown)
    """
    allowed_fails = (
        litellm_router_instance.get_allowed_fails_from_policy(
            exception=original_exception,
        )
        or litellm_router_instance.allowed_fails
    )
    cooldown_time = (
        litellm_router_instance.cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS
    )

    current_fails = litellm_router_instance.failed_calls.get_cache(key=deployment) or 0
    updated_fails = current_fails + 1

    if updated_fails > allowed_fails:
        return True
    else:
        litellm_router_instance.failed_calls.set_cache(
            key=deployment, value=updated_fails, ttl=cooldown_time
        )

    return False


def _is_allowed_fails_set_on_router(
    litellm_router_instance: LitellmRouter,
) -> bool:
    """
    Check if Router.allowed_fails is set to a non-default value

    Returns:
    - True if Router.allowed_fails is set to a non-default value
    - False if Router.allowed_fails is None or is the default value
    """
    if litellm_router_instance.allowed_fails is None:
        return False
    if litellm_router_instance.allowed_fails != litellm.allowed_fails:
        return True
    return False


def cast_exception_status_to_int(exception_status: Union[str, int]) -> int:
    if isinstance(exception_status, str):
        try:
            exception_status = int(exception_status)
        except Exception:
            verbose_router_logger.debug(
                f"Unable to cast exception status to int {exception_status}. Defaulting to status=500."
            )
            exception_status = 500
    return exception_status
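# Usage example (illustrative sketch, not part of this commit; `router` is an initialized
# litellm Router and `rate_limit_exc` a hypothetical RateLimitError instance):
# was_cooled_down = _set_cooldown_deployments(
#     litellm_router_instance=router,
#     original_exception=rate_limit_exc,
#     exception_status=429,
#     deployment="my-deployment-id",
# )
# # True -> the deployment id is excluded by _get_cooldown_deployments() until the TTL expires.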
@@ -0,0 +1,303 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import litellm
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.router_utils.add_retry_fallback_headers import (
    add_fallback_headers_to_response,
)
from litellm.types.router import LiteLLMParamsTypedDict

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


def _check_stripped_model_group(model_group: str, fallback_key: str) -> bool:
    """
    Handles the wildcard routing scenario where fallbacks are set like:
        [{"gpt-3.5-turbo": ["claude-3-haiku"]}]

    but model_group is like:
        "openai/gpt-3.5-turbo"

    Returns:
    - True if the stripped model group == fallback_key
    """
    for provider in litellm.provider_list:
        if isinstance(provider, Enum):
            _provider = provider.value
        else:
            _provider = provider
        if model_group.startswith(f"{_provider}/"):
            stripped_model_group = model_group.replace(f"{_provider}/", "")
            if stripped_model_group == fallback_key:
                return True
    return False


def get_fallback_model_group(
    fallbacks: List[Any], model_group: str
) -> Tuple[Optional[List[str]], Optional[int]]:
    """
    Returns:
    - fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
    - generic_fallback_idx: int of the index of the generic fallback in the fallbacks list.

    Checks:
    - exact match
    - stripped model group match
    - generic fallback
    """
    generic_fallback_idx: Optional[int] = None
    stripped_model_fallback: Optional[List[str]] = None
    fallback_model_group: Optional[List[str]] = None
    ## check for model group-specific fallbacks
    for idx, item in enumerate(fallbacks):
        if isinstance(item, dict):
            if list(item.keys())[0] == model_group:  # check exact match
                fallback_model_group = item[model_group]
                break
            elif _check_stripped_model_group(
                model_group=model_group, fallback_key=list(item.keys())[0]
            ):  # check stripped model group match
                stripped_model_fallback = item[list(item.keys())[0]]
            elif list(item.keys())[0] == "*":  # check generic fallback
                generic_fallback_idx = idx
        elif isinstance(item, str):
            fallback_model_group = [fallbacks.pop(idx)]  # returns single-item list
    ## if none, check for generic fallback
    if fallback_model_group is None:
        if stripped_model_fallback is not None:
            fallback_model_group = stripped_model_fallback
        elif generic_fallback_idx is not None:
            fallback_model_group = fallbacks[generic_fallback_idx]["*"]

    return fallback_model_group, generic_fallback_idx


async def run_async_fallback(
    *args: Tuple[Any],
    litellm_router: LitellmRouter,
    fallback_model_group: List[str],
    original_model_group: str,
    original_exception: Exception,
    max_fallbacks: int,
    fallback_depth: int,
    **kwargs,
) -> Any:
    """
    Loops through all the fallback model groups and calls kwargs["original_function"] with the arguments and keyword arguments provided.

    If the call is successful, it logs the success and returns the response.
    If the call fails, it logs the failure and continues to the next fallback model group.
    If all fallback model groups fail, it raises the most recent exception.

    Args:
        litellm_router: The litellm router instance.
        *args: Positional arguments.
        fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
        original_model_group: The original model group. example: "gpt-3.5-turbo"
        original_exception: The original exception.
        **kwargs: Keyword arguments.

    Returns:
        The response from the successful fallback model group.
    Raises:
        The most recent exception if all fallback model groups fail.
    """

    ### BASE CASE ### MAX FALLBACK DEPTH REACHED
    if fallback_depth >= max_fallbacks:
        raise original_exception

    error_from_fallbacks = original_exception

    for mg in fallback_model_group:
        if mg == original_model_group:
            continue
        try:
            # LOGGING
            kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
            verbose_router_logger.info(f"Falling back to model_group = {mg}")
            if isinstance(mg, str):
                kwargs["model"] = mg
            elif isinstance(mg, dict):
                kwargs.update(mg)
            kwargs.setdefault("metadata", {}).update(
                {"model_group": kwargs.get("model", None)}
            )  # update model_group used, if fallbacks are done
            fallback_depth = fallback_depth + 1
            kwargs["fallback_depth"] = fallback_depth
            kwargs["max_fallbacks"] = max_fallbacks
            response = await litellm_router.async_function_with_fallbacks(
                *args, **kwargs
            )
            verbose_router_logger.info("Successful fallback b/w models.")
            response = add_fallback_headers_to_response(
                response=response,
                attempted_fallbacks=fallback_depth,
            )
            # callback for successful_fallback_event()
            await log_success_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
            return response
        except Exception as e:
            error_from_fallbacks = e
            await log_failure_fallback_event(
                original_model_group=original_model_group,
                kwargs=kwargs,
                original_exception=original_exception,
            )
    raise error_from_fallbacks


async def log_success_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Log a successful fallback event to all registered callbacks.

    This function iterates through all callbacks, initializing _known_custom_logger_compatible_callbacks if needed,
    and calls the log_success_fallback_event method on CustomLogger instances.

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request

    Note:
        Errors during logging are caught and reported but do not interrupt the process.
    """
    from litellm.litellm_core_utils.litellm_logging import (
        _init_custom_logger_compatible_class,
    )

    for _callback in litellm.callbacks:
        if isinstance(_callback, CustomLogger) or (
            _callback in litellm._known_custom_logger_compatible_callbacks
        ):
            try:
                _callback_custom_logger: Optional[CustomLogger] = None
                if _callback in litellm._known_custom_logger_compatible_callbacks:
                    _callback_custom_logger = _init_custom_logger_compatible_class(
                        logging_integration=_callback,  # type: ignore
                        llm_router=None,
                        internal_usage_cache=None,
                    )
                elif isinstance(_callback, CustomLogger):
                    _callback_custom_logger = _callback
                else:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly"
                    )
                    continue

                if _callback_custom_logger is None:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly, callback is None"
                    )
                    continue

                await _callback_custom_logger.log_success_fallback_event(
                    original_model_group=original_model_group,
                    kwargs=kwargs,
                    original_exception=original_exception,
                )
            except Exception as e:
                verbose_router_logger.error(
                    f"Error in log_success_fallback_event: {str(e)}"
                )


async def log_failure_fallback_event(
    original_model_group: str, kwargs: dict, original_exception: Exception
):
    """
    Log a failed fallback event to all registered callbacks.

    This function iterates through all callbacks, initializing _known_custom_logger_compatible_callbacks if needed,
    and calls the log_failure_fallback_event method on CustomLogger instances.

    Args:
        original_model_group (str): The original model group before fallback.
        kwargs (dict): kwargs for the request

    Note:
        Errors during logging are caught and reported but do not interrupt the process.
    """
    from litellm.litellm_core_utils.litellm_logging import (
        _init_custom_logger_compatible_class,
    )

    for _callback in litellm.callbacks:
        if isinstance(_callback, CustomLogger) or (
            _callback in litellm._known_custom_logger_compatible_callbacks
        ):
            try:
                _callback_custom_logger: Optional[CustomLogger] = None
                if _callback in litellm._known_custom_logger_compatible_callbacks:
                    _callback_custom_logger = _init_custom_logger_compatible_class(
                        logging_integration=_callback,  # type: ignore
                        llm_router=None,
                        internal_usage_cache=None,
                    )
                elif isinstance(_callback, CustomLogger):
                    _callback_custom_logger = _callback
                else:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly"
                    )
                    continue

                if _callback_custom_logger is None:
                    verbose_router_logger.exception(
                        f"{_callback} logger not found / initialized properly"
                    )
                    continue

                await _callback_custom_logger.log_failure_fallback_event(
                    original_model_group=original_model_group,
                    kwargs=kwargs,
                    original_exception=original_exception,
                )
            except Exception as e:
                verbose_router_logger.error(
                    f"Error in log_failure_fallback_event: {str(e)}"
                )


def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool:
    """
    Checks if the fallbacks list is a list of strings or a list of dictionaries.

    Non-standard formats:
    - List[str]: e.g. ["claude-3-haiku", "openai/o-1"]
    - List[Dict[<LiteLLMParamsTypedDict>, Any]]: e.g. [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}]

    If [{"gpt-3.5-turbo": ["claude-3-haiku"]}], then it is the standard format.
    """
    if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0:
        return False
    if all(isinstance(item, str) for item in fallbacks):
        return True
    elif all(isinstance(item, dict) for item in fallbacks):
        for key in LiteLLMParamsTypedDict.__annotations__.keys():
            if key in fallbacks[0].keys():
                return True

    return False


def run_non_standard_fallback_format(
    fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str
):
    pass
@@ -0,0 +1,71 @@
"""
Get num retries for an exception.

- Account for retry policy by exception type.
"""

from typing import Dict, Optional, Union

from litellm.exceptions import (
    AuthenticationError,
    BadRequestError,
    ContentPolicyViolationError,
    RateLimitError,
    Timeout,
)
from litellm.types.router import RetryPolicy


def get_num_retries_from_retry_policy(
    exception: Exception,
    retry_policy: Optional[Union[RetryPolicy, dict]] = None,
    model_group: Optional[str] = None,
    model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
):
    """
    BadRequestErrorRetries: Optional[int] = None
    AuthenticationErrorRetries: Optional[int] = None
    TimeoutErrorRetries: Optional[int] = None
    RateLimitErrorRetries: Optional[int] = None
    ContentPolicyViolationErrorRetries: Optional[int] = None
    """
    # if we can find the exception in the retry policy -> return the number of retries
    if (
        model_group_retry_policy is not None
        and model_group is not None
        and model_group in model_group_retry_policy
    ):
        retry_policy = model_group_retry_policy.get(model_group, None)  # type: ignore

    if retry_policy is None:
        return None
    if isinstance(retry_policy, dict):
        retry_policy = RetryPolicy(**retry_policy)

    if (
        isinstance(exception, BadRequestError)
        and retry_policy.BadRequestErrorRetries is not None
    ):
        return retry_policy.BadRequestErrorRetries
    if (
        isinstance(exception, AuthenticationError)
        and retry_policy.AuthenticationErrorRetries is not None
    ):
        return retry_policy.AuthenticationErrorRetries
    if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
        return retry_policy.TimeoutErrorRetries
    if (
        isinstance(exception, RateLimitError)
        and retry_policy.RateLimitErrorRetries is not None
    ):
        return retry_policy.RateLimitErrorRetries
    if (
        isinstance(exception, ContentPolicyViolationError)
        and retry_policy.ContentPolicyViolationErrorRetries is not None
    ):
        return retry_policy.ContentPolicyViolationErrorRetries


def reset_retry_policy() -> RetryPolicy:
    return RetryPolicy()
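# Usage example (illustrative sketch, not part of this commit; `timeout_exc` is a
# hypothetical litellm Timeout instance):
# policy = RetryPolicy(TimeoutErrorRetries=2, RateLimitErrorRetries=4)
# get_num_retries_from_retry_policy(exception=timeout_exc, retry_policy=policy)  # -> 2
# # Returns None when no retry count is configured for the exception type.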
@@ -0,0 +1,90 @@
from typing import TYPE_CHECKING, Any, Optional, Union

from litellm._logging import verbose_router_logger
from litellm.constants import MAX_EXCEPTION_MESSAGE_LENGTH
from litellm.router_utils.cooldown_handlers import (
    _async_get_cooldown_deployments_with_debug_info,
)
from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterRateLimitError

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router as _Router

    LitellmRouter = _Router
    Span = Union[_Span, Any]
else:
    LitellmRouter = Any
    Span = Any


async def send_llm_exception_alert(
    litellm_router_instance: LitellmRouter,
    request_kwargs: dict,
    error_traceback_str: str,
    original_exception,
):
    """
    Sends a Slack / MS Teams alert for the LLM API call failure. Only runs if router.slack_alerting_logger is set.

    Parameters:
        litellm_router_instance (_Router): The LitellmRouter instance.
        original_exception (Any): The original exception that occurred.

    Returns:
        None
    """
    if litellm_router_instance is None:
        return

    if not hasattr(litellm_router_instance, "slack_alerting_logger"):
        return

    if litellm_router_instance.slack_alerting_logger is None:
        return

    if "proxy_server_request" in request_kwargs:
        # Do not send any alert if it's a request from the litellm proxy server;
        # the proxy is already instrumented to send LLM API call failures
        return

    litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
    exception_str = str(original_exception)
    if litellm_debug_info is not None:
        exception_str += litellm_debug_info
    exception_str += f"\n\n{error_traceback_str[:MAX_EXCEPTION_MESSAGE_LENGTH]}"

    await litellm_router_instance.slack_alerting_logger.send_alert(
        message=f"LLM API call failed: `{exception_str}`",
        level="High",
        alert_type=AlertType.llm_exceptions,
        alerting_metadata={},
    )


async def async_raise_no_deployment_exception(
    litellm_router_instance: LitellmRouter, model: str, parent_otel_span: Optional[Span]
):
    """
    Returns a RouterRateLimitError if no deployment is found for the given model.
    """
    verbose_router_logger.info(
        f"get_available_deployment for model: {model}, No deployment available"
    )
    model_ids = litellm_router_instance.get_model_ids(model_name=model)
    _cooldown_time = litellm_router_instance.cooldown_cache.get_min_cooldown(
        model_ids=model_ids, parent_otel_span=parent_otel_span
    )
    _cooldown_list = await _async_get_cooldown_deployments_with_debug_info(
        litellm_router_instance=litellm_router_instance,
        parent_otel_span=parent_otel_span,
    )
    return RouterRateLimitError(
        model=model,
        cooldown_time=_cooldown_time,
        enable_pre_call_checks=litellm_router_instance.enable_pre_call_checks,
        cooldown_list=_cooldown_list,
    )
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Class to handle llm wildcard routing and regex pattern matching
|
||||
"""
|
||||
|
||||
import copy
|
||||
import re
|
||||
from re import Match
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from litellm import get_llm_provider
|
||||
from litellm._logging import verbose_router_logger
|
||||
|
||||
|
||||
class PatternUtils:
|
||||
@staticmethod
|
||||
def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]:
|
||||
"""
|
||||
Calculate pattern specificity based on length and complexity.
|
||||
|
||||
Args:
|
||||
pattern: Regex pattern to analyze
|
||||
|
||||
Returns:
|
||||
Tuple of (length, complexity) for sorting
|
||||
"""
|
||||
complexity_chars = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"]
|
||||
ret_val = (
|
||||
len(pattern), # Longer patterns more specific
|
||||
sum(
|
||||
pattern.count(char) for char in complexity_chars
|
||||
), # More regex complexity
|
||||
)
|
||||
return ret_val
|
||||
|
||||
@staticmethod
|
||||
def sorted_patterns(
|
||||
patterns: Dict[str, List[Dict]]
|
||||
) -> List[Tuple[str, List[Dict]]]:
|
||||
"""
|
||||
Cached property for patterns sorted by specificity.
|
||||
|
||||
Returns:
|
||||
Sorted list of pattern-deployment tuples
|
||||
"""
|
||||
return sorted(
|
||||
patterns.items(),
|
||||
key=lambda x: PatternUtils.calculate_pattern_specificity(x[0]),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
class PatternMatchRouter:
|
||||
"""
|
||||
Class to handle llm wildcard routing and regex pattern matching
|
||||
|
||||
doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing
|
||||
|
||||
This class will store a mapping for regex pattern: List[Deployments]
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.patterns: Dict[str, List] = {}
|
||||
|
||||
def add_pattern(self, pattern: str, llm_deployment: Dict):
|
||||
"""
|
||||
Add a regex pattern and the corresponding llm deployments to the patterns
|
||||
|
||||
Args:
|
||||
pattern: str
|
||||
llm_deployment: str or List[str]
|
||||
"""
|
||||
# Convert the pattern to a regex
|
||||
regex = self._pattern_to_regex(pattern)
|
||||
if regex not in self.patterns:
|
||||
self.patterns[regex] = []
|
||||
self.patterns[regex].append(llm_deployment)
|
||||
|
||||
def _pattern_to_regex(self, pattern: str) -> str:
|
||||
"""
|
||||
Convert a wildcard pattern to a regex pattern
|
||||
|
||||
example:
|
||||
pattern: openai/*
|
||||
regex: openai/.*
|
||||
|
||||
pattern: openai/fo::*::static::*
|
||||
regex: openai/fo::.*::static::.*
|
||||
|
||||
Args:
|
||||
pattern: str
|
||||
|
||||
Returns:
|
||||
str: regex pattern
|
||||
"""
|
||||
# # Replace '*' with '.*' for regex matching
|
||||
# regex = pattern.replace("*", ".*")
|
||||
# # Escape other special characters
|
||||
# regex = re.escape(regex).replace(r"\.\*", ".*")
|
||||
# return f"^{regex}$"
|
||||
return re.escape(pattern).replace(r"\*", "(.*)")
|
||||
|
||||
def _return_pattern_matched_deployments(
|
||||
self, matched_pattern: Match, deployments: List[Dict]
|
||||
) -> List[Dict]:
|
||||
new_deployments = []
|
||||
for deployment in deployments:
|
||||
new_deployment = copy.deepcopy(deployment)
|
||||
new_deployment["litellm_params"][
|
||||
"model"
|
||||
] = PatternMatchRouter.set_deployment_model_name(
|
||||
matched_pattern=matched_pattern,
|
||||
litellm_deployment_litellm_model=deployment["litellm_params"]["model"],
|
||||
)
|
||||
new_deployments.append(new_deployment)
|
||||
|
||||
return new_deployments
|
||||
|
||||
def route(
|
||||
self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Route a requested model to the corresponding llm deployments based on the regex pattern
|
||||
|
||||
loop through all the patterns and find the matching pattern
|
||||
if a pattern is found, return the corresponding llm deployments
|
||||
if no pattern is found, return None
|
||||
|
||||
Args:
|
||||
request: str - the received model name from the user (can be a wildcard route). If none, No deployments will be returned.
|
||||
filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
|
||||
Returns:
|
||||
Optional[List[Deployment]]: llm deployments
|
||||
"""
|
||||
try:
|
||||
if request is None:
|
||||
return None
|
||||
|
||||
sorted_patterns = PatternUtils.sorted_patterns(self.patterns)
|
||||
regex_filtered_model_names = (
|
||||
[self._pattern_to_regex(m) for m in filtered_model_names]
|
||||
if filtered_model_names is not None
|
||||
else []
|
||||
)
|
||||
for pattern, llm_deployments in sorted_patterns:
|
||||
if (
|
||||
filtered_model_names is not None
|
||||
and pattern not in regex_filtered_model_names
|
||||
):
|
||||
continue
|
||||
pattern_match = re.match(pattern, request)
|
||||
if pattern_match:
|
||||
return self._return_pattern_matched_deployments(
|
||||
matched_pattern=pattern_match, deployments=llm_deployments
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
|
||||
|
||||
return None # No matching pattern found
|
||||
|
||||
@staticmethod
|
||||
def set_deployment_model_name(
|
||||
matched_pattern: Match,
|
||||
litellm_deployment_litellm_model: str,
|
||||
) -> str:
|
||||
"""
|
||||
Set the model name for the matched pattern llm deployment
|
||||
|
||||
E.g.:
|
||||
|
||||
Case 1:
|
||||
model_name: llmengine/* (can be any regex pattern or wildcard pattern)
|
||||
litellm_params:
|
||||
model: openai/*
|
||||
|
||||
if model_name = "llmengine/foo" -> model = "openai/foo"
|
||||
|
||||
Case 2:
|
||||
model_name: llmengine/fo::*::static::*
|
||||
litellm_params:
|
||||
model: openai/fo::*::static::*
|
||||
|
||||
if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"
|
||||
|
||||
Case 3:
|
||||
model_name: *meta.llama3*
|
||||
litellm_params:
|
||||
model: bedrock/meta.llama3*
|
||||
|
||||
if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
|
||||
"""
|
||||
|
||||
## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
|
||||
if "*" not in litellm_deployment_litellm_model:
|
||||
return litellm_deployment_litellm_model
|
||||
|
||||
wildcard_count = litellm_deployment_litellm_model.count("*")
|
||||
|
||||
# Extract all dynamic segments from the request
|
||||
dynamic_segments = matched_pattern.groups()
|
||||
|
||||
if len(dynamic_segments) > wildcard_count:
|
||||
return (
|
||||
matched_pattern.string
|
||||
) # default to the user input, if unable to map based on wildcards.
|
||||
# Replace the corresponding wildcards in the litellm model pattern with extracted segments
|
||||
for segment in dynamic_segments:
|
||||
litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
|
||||
"*", segment, 1
|
||||
)
|
||||
|
||||
return litellm_deployment_litellm_model
|
||||
|
||||
    def get_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> Optional[List[Dict]]:
        """
        Check if a pattern exists for the given model and custom llm provider

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            Optional[List[Dict]]: matching llm deployments if a pattern exists, None otherwise
        """
        if custom_llm_provider is None:
            try:
                (
                    _,
                    custom_llm_provider,
                    _,
                    _,
                ) = get_llm_provider(model=model)
            except Exception:
                # get_llm_provider raises exception when provider is unknown
                pass
        return self.route(model) or self.route(f"{custom_llm_provider}/{model}")

    def get_deployments_by_pattern(
        self, model: str, custom_llm_provider: Optional[str] = None
    ) -> List[Dict]:
        """
        Get the deployments by pattern

        Args:
            model: str
            custom_llm_provider: Optional[str]

        Returns:
            List[Dict]: llm deployments matching the pattern
        """
        pattern_match = self.get_pattern(model, custom_llm_provider)
        if pattern_match:
            return pattern_match
        return []


# Example usage:
# router = PatternMatchRouter()
# router.add_pattern('openai/*', [Deployment(), Deployment()])
# router.add_pattern('openai/fo::*::static::*', Deployment())
# print(router.route('openai/gpt-4'))  # Output: [Deployment(), Deployment()]
# print(router.route('openai/fo::hi::static::hi'))  # Output: [Deployment()]
# print(router.route('something/else'))  # Output: None
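The wildcard mapping above can be illustrated with a small standalone sketch (illustrative only, not litellm's implementation; wildcard_to_regex below is a hypothetical stand-in for the router's _pattern_to_regex helper, and the example routes are made up):

import re


def wildcard_to_regex(pattern: str) -> str:
    # Escape regex metacharacters, then turn each escaped '*' into a capture group
    return re.escape(pattern).replace(r"\*", "(.*)")


def map_model_name(request: str, route_pattern: str, deployment_model: str) -> str:
    match = re.match(wildcard_to_regex(route_pattern), request)
    if match is None:
        return request  # no pattern match -> leave the requested name untouched
    mapped = deployment_model
    for segment in match.groups():
        mapped = mapped.replace("*", segment, 1)  # fill wildcards left-to-right
    return mapped


print(map_model_name("llmengine/foo", "llmengine/*", "openai/*"))      # openai/foo
print(map_model_name("llmengine/a::b", "llmengine/*::*", "openai/*::*"))  # openai/a::b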
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,99 @@
"""
Check if prompt caching is valid for a given deployment

Route to previously cached model id, if valid
"""

from typing import List, Optional, cast

from litellm import verbose_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes, StandardLoggingPayload
from litellm.utils import is_prompt_caching_valid_prompt

from ..prompt_caching_cache import PromptCachingCache


class PromptCachingDeploymentCheck(CustomLogger):
    def __init__(self, cache: DualCache):
        self.cache = cache

    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        if messages is not None and is_prompt_caching_valid_prompt(
            messages=messages,
            model=model,
        ):  # prompt > 1024 tokens
            prompt_cache = PromptCachingCache(
                cache=self.cache,
            )

            model_id_dict = await prompt_cache.async_get_model_id(
                messages=cast(List[AllMessageValues], messages),
                tools=None,
            )
            if model_id_dict is not None:
                model_id = model_id_dict["model_id"]
                for deployment in healthy_deployments:
                    if deployment["model_info"]["id"] == model_id:
                        return [deployment]

        return healthy_deployments

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
            "standard_logging_object", None
        )

        if standard_logging_object is None:
            return

        call_type = standard_logging_object["call_type"]

        if (
            call_type != CallTypes.completion.value
            and call_type != CallTypes.acompletion.value
        ):  # only use prompt caching for completion calls
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION"
            )
            return

        model = standard_logging_object["model"]
        messages = standard_logging_object["messages"]
        model_id = standard_logging_object["model_id"]

        if messages is None or not isinstance(messages, list):
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
            )
            return
        if model_id is None:
            verbose_logger.debug(
                "litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
            )
            return

        ## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
        if is_prompt_caching_valid_prompt(
            model=model,
            messages=cast(List[AllMessageValues], messages),
        ):
            cache = PromptCachingCache(
                cache=self.cache,
            )
            await cache.async_add_model_id(
                model_id=model_id,
                messages=messages,
                tools=None,  # [TODO]: add tools once standard_logging_object supports it
            )

        return
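The check above boils down to two steps: on a successful call, remember which deployment id served a cache-eligible prompt; on later requests with the same prefix, pin routing to that deployment. A minimal standalone sketch of that idea, using a plain dict instead of litellm's DualCache:

from typing import Dict, List, Optional

_prefix_to_deployment: Dict[str, str] = {}  # prompt-prefix key -> deployment id


def remember_deployment(prefix_key: str, deployment_id: str) -> None:
    _prefix_to_deployment[prefix_key] = deployment_id  # recorded on success events


def filter_deployments(prefix_key: str, healthy: List[dict]) -> List[dict]:
    cached_id: Optional[str] = _prefix_to_deployment.get(prefix_key)
    if cached_id is None:
        return healthy  # no affinity yet -> normal load balancing
    pinned = [d for d in healthy if d["model_info"]["id"] == cached_id]
    return pinned or healthy  # fall back if the pinned deployment is no longer healthy


remember_deployment("prefix-123", "deployment-a")
print(filter_deployments("prefix-123", [
    {"model_info": {"id": "deployment-a"}},
    {"model_info": {"id": "deployment-b"}},
]))  # -> only deployment-a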
@@ -0,0 +1,45 @@
"""
For Responses API, we need routing affinity when a user sends a previous_response_id.

E.g. if proxy admins are load balancing between N gpt-4.1-turbo deployments, and a user sends a previous_response_id,
we want to route to the same gpt-4.1-turbo deployment.

This is different from the normal behavior of the router, which does not have routing affinity for previous_response_id.


If previous_response_id is provided, route to the deployment that returned the previous response.
"""

from typing import List, Optional

from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues


class ResponsesApiDeploymentCheck(CustomLogger):
    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,
    ) -> List[dict]:
        request_kwargs = request_kwargs or {}
        previous_response_id = request_kwargs.get("previous_response_id", None)
        if previous_response_id is None:
            return healthy_deployments

        decoded_response = ResponsesAPIRequestUtils._decode_responses_api_response_id(
            response_id=previous_response_id,
        )
        model_id = decoded_response.get("model_id")
        if model_id is None:
            return healthy_deployments

        for deployment in healthy_deployments:
            if deployment["model_info"]["id"] == model_id:
                return [deployment]

        return healthy_deployments
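The affinity only works if the previous_response_id carries enough information to recover the serving deployment. A standalone sketch of one possible encoding (hypothetical; litellm's actual response-id format and ResponsesAPIRequestUtils._decode_responses_api_response_id may differ):

import base64
import json
from typing import Optional


def encode_response_id(raw_response_id: str, model_id: str) -> str:
    # Pack the serving deployment's model_id into the outgoing response id
    payload = json.dumps({"response_id": raw_response_id, "model_id": model_id})
    return "resp_" + base64.urlsafe_b64encode(payload.encode()).decode()


def decode_model_id(previous_response_id: str) -> Optional[str]:
    token = previous_response_id
    if token.startswith("resp_"):
        token = token[len("resp_"):]
    try:
        payload = base64.urlsafe_b64decode(token.encode())
        return json.loads(payload).get("model_id")
    except Exception:
        return None  # unknown format -> no affinity, fall back to normal routing


rid = encode_response_id("abc123", "deployment-a")
print(decode_model_id(rid))        # deployment-a
print(decode_model_id("garbage"))  # None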
@@ -0,0 +1,171 @@
"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""

import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, TypedDict, Union

from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router

    litellm_router = Router
    Span = Union[_Span, Any]
else:
    Span = Any
    litellm_router = Any


class PromptCachingCacheValue(TypedDict):
    model_id: str


class PromptCachingCache:
    def __init__(self, cache: DualCache):
        self.cache = cache
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
        if hasattr(obj, "dict"):
            # If the object is a Pydantic model, use its `dict()` method
            return obj.dict()
        elif isinstance(obj, dict):
            # If the object is a dictionary, serialize it with sorted keys
            return json.dumps(
                obj, sort_keys=True, separators=(",", ":")
            )  # Standardize serialization

        elif isinstance(obj, list):
            # Serialize lists by ensuring each element is handled properly
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        if messages is None and tools is None:
            return None
        # Use serialize_object for consistent and stable serialization
        data_to_hash = {}
        if messages is not None:
            serialized_messages = PromptCachingCache.serialize_object(messages)
            data_to_hash["messages"] = serialized_messages
        if tools is not None:
            serialized_tools = PromptCachingCache.serialize_object(tools)
            data_to_hash["tools"] = serialized_tools

        # Combine serialized data into a single string
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )

        # Create a hash of the serialized data for a stable cache key
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        if messages is not None:
        - check full messages
        - check messages[:-1]
        - check messages[:-2]
        - check messages[:-3]

        use self.cache.async_batch_get_cache(keys=potential_cache_keys)
        """
        if messages is None and tools is None:
            return None

        # Generate potential cache keys by slicing messages

        potential_cache_keys = []

        if messages is not None:
            full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                messages, tools
            )
            potential_cache_keys.append(full_cache_key)

            # Check progressively shorter message slices
            for i in range(1, min(4, len(messages))):
                partial_messages = messages[:-i]
                partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                    partial_messages, tools
                )
                potential_cache_keys.append(partial_cache_key)

        # Perform batch cache lookup
        cache_results = await self.cache.async_batch_get_cache(
            keys=potential_cache_keys
        )

        if cache_results is None:
            return None

        # Return the first non-None cache result
        for result in cache_results:
            if result is not None:
                return result

        return None

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        return self.cache.get_cache(cache_key)
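A standalone sketch of the cache-key strategy used above: hash a canonical JSON serialization of the messages (sorted keys, fixed separators) so the same prefix always yields the same key, then probe progressively shorter prefixes on lookup. It mirrors the logic in PromptCachingCache against a plain dict rather than the router's DualCache:

import hashlib
import json
from typing import Dict, List, Optional


def prefix_key(messages: List[dict]) -> str:
    canonical = json.dumps({"messages": messages}, sort_keys=True, separators=(",", ":"))
    return f"deployment:{hashlib.sha256(canonical.encode()).hexdigest()}:prompt_caching"


store: Dict[str, str] = {}  # cache key -> model id


def remember(messages: List[dict], model_id: str) -> None:
    store[prefix_key(messages)] = model_id


def lookup(messages: List[dict]) -> Optional[str]:
    # Try the full conversation first, then drop up to 3 trailing messages
    for i in range(0, min(4, len(messages))):
        candidate = messages if i == 0 else messages[:-i]
        hit = store.get(prefix_key(candidate))
        if hit is not None:
            return hit
    return None


history = [{"role": "user", "content": "long cached prompt..."}]
remember(history, "deployment-a")
print(lookup(history + [{"role": "user", "content": "follow-up"}]))  # deployment-a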
Binary file not shown.
@@ -0,0 +1,90 @@
"""
Helper functions to get/set num success and num failures per deployment


increment_deployment_successes_for_current_minute
increment_deployment_failures_for_current_minute

get_deployment_successes_for_current_minute
get_deployment_failures_for_current_minute
"""

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any


def increment_deployment_successes_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> str:
    """
    In-Memory: Increments the number of successes for the current minute for a deployment_id
    """
    key = f"{deployment_id}:successes"
    litellm_router_instance.cache.increment_cache(
        local_only=True,
        key=key,
        value=1,
        ttl=60,
    )
    return key


def increment_deployment_failures_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
):
    """
    In-Memory: Increments the number of failures for the current minute for a deployment_id
    """
    key = f"{deployment_id}:fails"
    litellm_router_instance.cache.increment_cache(
        local_only=True,
        key=key,
        value=1,
        ttl=60,
    )


def get_deployment_successes_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> int:
    """
    Returns the number of successes for the current minute for a deployment_id

    Returns 0 if no value found
    """
    key = f"{deployment_id}:successes"
    return (
        litellm_router_instance.cache.get_cache(
            local_only=True,
            key=key,
        )
        or 0
    )


def get_deployment_failures_for_current_minute(
    litellm_router_instance: LitellmRouter,
    deployment_id: str,
) -> int:
    """
    Returns the number of fails for the current minute for a deployment_id

    Returns 0 if no value found
    """
    key = f"{deployment_id}:fails"
    return (
        litellm_router_instance.cache.get_cache(
            local_only=True,
            key=key,
        )
        or 0
    )
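The ttl=60 on each increment is what makes these counters "for the current minute": the key expires on its own, so reads reflect roughly the last 60 seconds without any explicit reset. A minimal standalone sketch of that behavior, assuming a TTL-aware key-value store like the router cache:

import time
from typing import Dict, Tuple


class TTLCounter:
    def __init__(self) -> None:
        self._data: Dict[str, Tuple[int, float]] = {}  # key -> (value, expiry timestamp)

    def increment(self, key: str, value: int = 1, ttl: int = 60) -> int:
        current, expires_at = self._data.get(key, (0, 0.0))
        if time.time() >= expires_at:
            current, expires_at = 0, time.time() + ttl  # window expired -> restart count
        self._data[key] = (current + value, expires_at)
        return current + value

    def get(self, key: str) -> int:
        value, expires_at = self._data.get(key, (0, 0.0))
        return value if time.time() < expires_at else 0


counters = TTLCounter()
counters.increment("deployment-a:successes")
counters.increment("deployment-a:successes")
print(counters.get("deployment-a:successes"))  # 2 (within the 60s window)
print(counters.get("deployment-a:fails"))      # 0 (no value found)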