structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,68 @@
from typing import Any, Optional, Union
from pydantic import BaseModel
from litellm.types.utils import HiddenParams
def _add_headers_to_response(response: Any, headers: dict) -> Any:
"""
Helper function to add headers to a response's hidden params
"""
if response is None or not isinstance(response, BaseModel):
return response
hidden_params: Optional[Union[dict, HiddenParams]] = getattr(
response, "_hidden_params", {}
)
if hidden_params is None:
hidden_params = {}
elif isinstance(hidden_params, HiddenParams):
hidden_params = hidden_params.model_dump()
hidden_params.setdefault("additional_headers", {})
hidden_params["additional_headers"].update(headers)
setattr(response, "_hidden_params", hidden_params)
return response
def add_retry_headers_to_response(
response: Any,
attempted_retries: int,
max_retries: Optional[int] = None,
) -> Any:
"""
Add retry headers to the response's hidden params
"""
retry_headers = {
"x-litellm-attempted-retries": attempted_retries,
}
if max_retries is not None:
retry_headers["x-litellm-max-retries"] = max_retries
return _add_headers_to_response(response, retry_headers)
def add_fallback_headers_to_response(
response: Any,
attempted_fallbacks: int,
) -> Any:
"""
Add fallback headers to the response
Args:
response: The response to add the headers to
attempted_fallbacks: The number of fallbacks attempted
Returns:
The response with the headers added
Note: It's intentional that we don't add max_fallbacks in response headers
Want to avoid bloat in the response headers for performance.
"""
fallback_headers = {
"x-litellm-attempted-fallbacks": attempted_fallbacks,
}
return _add_headers_to_response(response, fallback_headers)
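A minimal usage sketch for the helpers above; FakeResponse is a stand-in pydantic model (it declares _hidden_params so setattr is permitted) and is not a real litellm response type:
from pydantic import BaseModel

class FakeResponse(BaseModel):
    id: str = "resp-1"
    _hidden_params: dict = {}  # litellm responses expose this; declared here so the stand-in accepts it

resp = add_retry_headers_to_response(FakeResponse(), attempted_retries=2, max_retries=3)
resp = add_fallback_headers_to_response(resp, attempted_fallbacks=1)
# resp._hidden_params["additional_headers"] now contains:
# {"x-litellm-attempted-retries": 2, "x-litellm-max-retries": 3, "x-litellm-attempted-fallbacks": 1}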

View File

@@ -0,0 +1,63 @@
import io
import json
from typing import Optional, Tuple, Union
class InMemoryFile(io.BytesIO):
def __init__(self, content: bytes, name: str):
super().__init__(content)
self.name = name
def replace_model_in_jsonl(
file_content: Union[bytes, Tuple[str, bytes, str]], new_model_name: str
) -> Optional[InMemoryFile]:
try:
# Decode the bytes to a string and split into lines
# If file_content is a file-like object, read the bytes
if hasattr(file_content, "read"):
file_content_bytes = file_content.read() # type: ignore
elif isinstance(file_content, tuple):
file_content_bytes = file_content[1]
else:
file_content_bytes = file_content
# Decode the bytes to a string and split into lines
if isinstance(file_content_bytes, bytes):
file_content_str = file_content_bytes.decode("utf-8")
else:
file_content_str = file_content_bytes
lines = file_content_str.splitlines()
modified_lines = []
for line in lines:
# Parse each line as a JSON object
json_object = json.loads(line.strip())
# Replace the model name if it exists
if "body" in json_object:
json_object["body"]["model"] = new_model_name
# Convert the modified JSON object back to a string
modified_lines.append(json.dumps(json_object))
# Reassemble the modified lines and return as bytes
modified_file_content = "\n".join(modified_lines).encode("utf-8")
return InMemoryFile(modified_file_content, name="modified_file.jsonl") # type: ignore
except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
return None
def _get_router_metadata_variable_name(function_name: str) -> str:
"""
Helper to return what the "metadata" field should be called in the request data
For all /thread or /assistant endpoints we need to call this "litellm_metadata"
For ALL other endpoints we call this "metadata"
"""
ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"])
if function_name in ROUTER_METHODS_USING_LITELLM_METADATA:
return "litellm_metadata"
else:
return "metadata"

View File

@@ -0,0 +1,37 @@
import asyncio
from typing import TYPE_CHECKING, Any
from litellm.utils import calculate_max_parallel_requests
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
class InitalizeCachedClient:
@staticmethod
def set_max_parallel_requests_client(
litellm_router_instance: LitellmRouter, model: dict
):
litellm_params = model.get("litellm_params", {})
model_id = model["model_info"]["id"]
rpm = litellm_params.get("rpm", None)
tpm = litellm_params.get("tpm", None)
max_parallel_requests = litellm_params.get("max_parallel_requests", None)
calculated_max_parallel_requests = calculate_max_parallel_requests(
rpm=rpm,
max_parallel_requests=max_parallel_requests,
tpm=tpm,
default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests,
)
if calculated_max_parallel_requests:
semaphore = asyncio.Semaphore(calculated_max_parallel_requests)
cache_key = f"{model_id}_max_parallel_requests_client"
litellm_router_instance.cache.set_cache(
key=cache_key,
value=semaphore,
local_only=True,
)
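The semaphore stored above can be read back with the same key convention; a small sketch, mirroring the local_only cache reads used elsewhere in this commit:
def get_max_parallel_requests_semaphore(litellm_router_instance: LitellmRouter, model_id: str):
    # look up the per-deployment semaphore cached by set_max_parallel_requests_client
    cache_key = f"{model_id}_max_parallel_requests_client"
    return litellm_router_instance.cache.get_cache(key=cache_key, local_only=True)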

View File

@@ -0,0 +1,37 @@
"""
Utils for handling clientside credentials
Supported clientside credentials:
- api_key
- api_base
- base_url
If given, generate a unique model_id for the deployment.
Ensures cooldowns are applied correctly.
"""
clientside_credential_keys = ["api_key", "api_base", "base_url"]
def is_clientside_credential(request_kwargs: dict) -> bool:
"""
Check if the credential is a clientside credential.
"""
return any(key in request_kwargs for key in clientside_credential_keys)
def get_dynamic_litellm_params(litellm_params: dict, request_kwargs: dict) -> dict:
"""
Update litellm_params with the clientside credentials from request_kwargs.
Returns:
- litellm_params: dict, used downstream for generating a unique model_id for the deployment.
"""
# update litellm_params with clientside credentials
for key in clientside_credential_keys:
if key in request_kwargs:
litellm_params[key] = request_kwargs[key]
return litellm_params
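A data-only sketch of the two helpers above; the api_key value is a placeholder:
request_kwargs = {"api_key": "sk-user-supplied", "messages": []}
assert is_clientside_credential(request_kwargs) is True

litellm_params = {"model": "openai/gpt-4o"}
merged = get_dynamic_litellm_params(litellm_params=litellm_params, request_kwargs=request_kwargs)
assert merged == {"model": "openai/gpt-4o", "api_key": "sk-user-supplied"}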

View File

@@ -0,0 +1,14 @@
import hashlib
import json
from litellm.types.router import CredentialLiteLLMParams
def get_litellm_params_sensitive_credential_hash(litellm_params: dict) -> str:
"""
Hash of the credential params, used for mapping the file id to the right model
"""
sensitive_params = CredentialLiteLLMParams(**litellm_params)
return hashlib.sha256(
json.dumps(sensitive_params.model_dump()).encode()
).hexdigest()
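A hedged sketch of the hash helper; it assumes api_key and api_base are among CredentialLiteLLMParams' fields, and the values are placeholders:
params = {"api_key": "sk-placeholder", "api_base": "https://example.com/v1"}
credential_hash = get_litellm_params_sensitive_credential_hash(params)
# identical credential params always produce the same sha256 digest,
# so the hash can serve as a stable mapping key (e.g. file id -> model)
assert credential_hash == get_litellm_params_sensitive_credential_hash(dict(params))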

View File

@@ -0,0 +1,170 @@
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""
import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict, Union
from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class CooldownCacheValue(TypedDict):
exception_received: str
status_code: str
timestamp: float
cooldown_time: float
class CooldownCache:
def __init__(self, cache: DualCache, default_cooldown_time: float):
self.cache = cache
self.default_cooldown_time = default_cooldown_time
self.in_memory_cache = InMemoryCache()
def _common_add_cooldown_logic(
self, model_id: str, original_exception, exception_status, cooldown_time: float
) -> Tuple[str, CooldownCacheValue]:
try:
current_time = time.time()
cooldown_key = f"deployment:{model_id}:cooldown"
# Store the cooldown information for the deployment separately
cooldown_data = CooldownCacheValue(
exception_received=str(original_exception),
status_code=str(exception_status),
timestamp=current_time,
cooldown_time=cooldown_time,
)
return cooldown_key, cooldown_data
except Exception as e:
verbose_logger.error(
"CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
str(e)
)
)
raise e
def add_deployment_to_cooldown(
self,
model_id: str,
original_exception: Exception,
exception_status: int,
cooldown_time: Optional[float],
):
try:
_cooldown_time = cooldown_time or self.default_cooldown_time
cooldown_key, cooldown_data = self._common_add_cooldown_logic(
model_id=model_id,
original_exception=original_exception,
exception_status=exception_status,
cooldown_time=_cooldown_time,
)
# Set the cache with a TTL equal to the cooldown time
self.cache.set_cache(
value=cooldown_data,
key=cooldown_key,
ttl=_cooldown_time,
)
except Exception as e:
verbose_logger.error(
"CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
str(e)
)
)
raise e
@staticmethod
def get_cooldown_cache_key(model_id: str) -> str:
return f"deployment:{model_id}:cooldown"
async def async_get_active_cooldowns(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments
keys = [
CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
]
# Retrieve the values for the keys using mget
## more likely to be None if no models are rate limited, so just check redis every 1s
## each redis call adds ~100ms latency.
## check in-memory cache first
results = await self.cache.async_batch_get_cache(
keys=keys, parent_otel_span=parent_otel_span
)
active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []
if results is None:
return active_cooldowns
# Process the results
for model_id, result in zip(model_ids, results):
if result and isinstance(result, dict):
cooldown_cache_value = CooldownCacheValue(**result) # type: ignore
active_cooldowns.append((model_id, cooldown_cache_value))
return active_cooldowns
def get_active_cooldowns(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget
results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
active_cooldowns = []
# Process the results
for model_id, result in zip(model_ids, results):
if result and isinstance(result, dict):
cooldown_cache_value = CooldownCacheValue(**result) # type: ignore
active_cooldowns.append((model_id, cooldown_cache_value))
return active_cooldowns
def get_min_cooldown(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> float:
"""Return min cooldown time required for a group of model id's."""
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget
results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
min_cooldown_time: Optional[float] = None
# Process the results
for model_id, result in zip(model_ids, results):
if result and isinstance(result, dict):
cooldown_cache_value = CooldownCacheValue(**result) # type: ignore
if min_cooldown_time is None:
min_cooldown_time = cooldown_cache_value["cooldown_time"]
elif cooldown_cache_value["cooldown_time"] < min_cooldown_time:
min_cooldown_time = cooldown_cache_value["cooldown_time"]
return min_cooldown_time or self.default_cooldown_time
# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, default_cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(model_id=model_id, original_exception=original_exception, exception_status=exception_status, cooldown_time=None)
# active_cooldowns = cooldown_cache.get_active_cooldowns(model_ids=[model_id], parent_otel_span=None)

View File

@@ -0,0 +1,98 @@
"""
Callbacks triggered on cooling down deployments
"""
import copy
from typing import TYPE_CHECKING, Any, Optional, Union
import litellm
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
from litellm.integrations.prometheus import PrometheusLogger
else:
LitellmRouter = Any
PrometheusLogger = Any
async def router_cooldown_event_callback(
litellm_router_instance: LitellmRouter,
deployment_id: str,
exception_status: Union[str, int],
cooldown_time: float,
):
"""
Callback triggered when a deployment is put into cooldown by litellm
- Updates deployment state on Prometheus
- Increments cooldown metric for deployment on Prometheus
"""
verbose_logger.debug("In router_cooldown_event_callback - updating prometheus")
_deployment = litellm_router_instance.get_deployment(model_id=deployment_id)
if _deployment is None:
verbose_logger.warning(
f"in router_cooldown_event_callback but _deployment is None for deployment_id={deployment_id}. Doing nothing"
)
return
_litellm_params = _deployment["litellm_params"]
temp_litellm_params = copy.deepcopy(_litellm_params)
temp_litellm_params = dict(temp_litellm_params)
_model_name = _deployment.get("model_name", None) or ""
_api_base = (
litellm.get_api_base(model=_model_name, optional_params=temp_litellm_params)
or ""
)
model_info = _deployment["model_info"]
model_id = model_info.id
litellm_model_name = temp_litellm_params.get("model") or ""
llm_provider = ""
try:
_, llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_model_name,
custom_llm_provider=temp_litellm_params.get("custom_llm_provider"),
)
except Exception:
pass
# get the prometheus logger from in memory loggers
prometheusLogger: Optional[
PrometheusLogger
] = _get_prometheus_logger_from_callbacks()
if prometheusLogger is not None:
prometheusLogger.set_deployment_complete_outage(
litellm_model_name=_model_name,
model_id=model_id,
api_base=_api_base,
api_provider=llm_provider,
)
prometheusLogger.increment_deployment_cooled_down(
litellm_model_name=_model_name,
model_id=model_id,
api_base=_api_base,
api_provider=llm_provider,
exception_status=str(exception_status),
)
return
def _get_prometheus_logger_from_callbacks() -> Optional[PrometheusLogger]:
"""
Checks if a PrometheusLogger callback is initialized; if yes, returns it
"""
from litellm.integrations.prometheus import PrometheusLogger
for _callback in litellm._async_success_callback:
if isinstance(_callback, PrometheusLogger):
return _callback
for global_callback in litellm.callbacks:
if isinstance(global_callback, PrometheusLogger):
return global_callback
return None

View File

@@ -0,0 +1,438 @@
"""
Router cooldown handlers
- _set_cooldown_deployments: puts a deployment in the cooldown list
- _get_cooldown_deployments: returns the list of deployments in the cooldown list
- _async_get_cooldown_deployments: ASYNC: returns the list of deployments in the cooldown list
"""
import asyncio
from typing import TYPE_CHECKING, Any, List, Optional, Union
import litellm
from litellm._logging import verbose_router_logger
from litellm.constants import (
DEFAULT_COOLDOWN_TIME_SECONDS,
DEFAULT_FAILURE_THRESHOLD_PERCENT,
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD,
)
from litellm.router_utils.cooldown_callbacks import router_cooldown_event_callback
from .router_callbacks.track_deployment_metrics import (
get_deployment_failures_for_current_minute,
get_deployment_successes_for_current_minute,
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router as _Router
LitellmRouter = _Router
Span = Union[_Span, Any]
else:
LitellmRouter = Any
Span = Any
def _is_cooldown_required(
litellm_router_instance: LitellmRouter,
model_id: str,
exception_status: Union[str, int],
exception_str: Optional[str] = None,
) -> bool:
"""
A function to determine if a cooldown is required based on the exception status.
Parameters:
model_id (str) The id of the model in the model list
exception_status (Union[str, int]): The status of the exception.
Returns:
bool: True if a cooldown is required, False otherwise.
"""
try:
ignored_strings = ["APIConnectionError"]
if (
exception_str is not None
): # don't cooldown on litellm api connection errors
for ignored_string in ignored_strings:
if ignored_string in exception_str:
return False
if isinstance(exception_status, str):
exception_status = int(exception_status)
if exception_status >= 400 and exception_status < 500:
if exception_status == 429:
# Cool down 429 Rate Limit Errors
return True
elif exception_status == 401:
# Cool down 401 Auth Errors
return True
elif exception_status == 408:
return True
elif exception_status == 404:
return True
else:
# Do NOT cool down all other 4XX Errors
return False
else:
# should cool down for all other errors
return True
except Exception:
# Catch all - if any exceptions default to cooling down
return True
def _should_run_cooldown_logic(
litellm_router_instance: LitellmRouter,
deployment: Optional[str],
exception_status: Union[str, int],
original_exception: Any,
) -> bool:
"""
Helper that decides if cooldown logic should be run
Returns False if cooldown logic should not be run
Does not run cooldown logic when:
- router.disable_cooldowns is True
- deployment is None
- _is_cooldown_required() returns False
- deployment is in litellm_router_instance.provider_default_deployment_ids
- exception_status is not one that should be immediately retried (e.g. 401)
"""
if (
deployment is None
or litellm_router_instance.get_model_group(id=deployment) is None
):
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: deployment id is none or model group can't be found."
)
return False
if litellm_router_instance.disable_cooldowns:
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: disable_cooldowns is True"
)
return False
if deployment is None:
verbose_router_logger.debug("Should Not Run Cooldown Logic: deployment is None")
return False
if not _is_cooldown_required(
litellm_router_instance=litellm_router_instance,
model_id=deployment,
exception_status=exception_status,
exception_str=str(original_exception),
):
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: _is_cooldown_required returned False"
)
return False
if deployment in litellm_router_instance.provider_default_deployment_ids:
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: deployment is in provider_default_deployment_ids"
)
return False
return True
def _should_cooldown_deployment(
litellm_router_instance: LitellmRouter,
deployment: str,
exception_status: Union[str, int],
original_exception: Any,
) -> bool:
"""
Helper that decides if a deployment should be put in cooldown
Returns True if the deployment should be put in cooldown
Returns False if the deployment should not be put in cooldown
Deployment is put in cooldown when:
- v2 logic (Current):
cooldown if:
- got a 429 error from LLM API
- if %fails / (successes + fails) > DEFAULT_FAILURE_THRESHOLD_PERCENT
- got 401 Auth error, 404 NotFoundError - checked by litellm._should_retry()
- v1 logic (Legacy): if allowed_fails or allowed_fails_policy is set, cools down if num fails in this minute > allowed fails
"""
## BASE CASE - single deployment
model_group = litellm_router_instance.get_model_group(id=deployment)
is_single_deployment_model_group = False
if model_group is not None and len(model_group) == 1:
is_single_deployment_model_group = True
if (
litellm_router_instance.allowed_fails_policy is None
and _is_allowed_fails_set_on_router(
litellm_router_instance=litellm_router_instance
)
is False
):
num_successes_this_minute = get_deployment_successes_for_current_minute(
litellm_router_instance=litellm_router_instance, deployment_id=deployment
)
num_fails_this_minute = get_deployment_failures_for_current_minute(
litellm_router_instance=litellm_router_instance, deployment_id=deployment
)
total_requests_this_minute = num_successes_this_minute + num_fails_this_minute
percent_fails = 0.0
if total_requests_this_minute > 0:
percent_fails = num_fails_this_minute / (
num_successes_this_minute + num_fails_this_minute
)
verbose_router_logger.debug(
"percent fails for deployment = %s, percent fails = %s, num successes = %s, num fails = %s",
deployment,
percent_fails,
num_successes_this_minute,
num_fails_this_minute,
)
exception_status_int = cast_exception_status_to_int(exception_status)
if exception_status_int == 429 and not is_single_deployment_model_group:
return True
elif (
percent_fails == 1.0
and total_requests_this_minute
>= SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD
):
# Cooldown if all requests failed and we have reasonable traffic
return True
elif (
percent_fails > DEFAULT_FAILURE_THRESHOLD_PERCENT
and not is_single_deployment_model_group # by default we should avoid cooldowns on single deployment model groups
):
return True
elif (
litellm._should_retry(
status_code=cast_exception_status_to_int(exception_status)
)
is False
):
return True
return False
else:
return should_cooldown_based_on_allowed_fails_policy(
litellm_router_instance=litellm_router_instance,
deployment=deployment,
original_exception=original_exception,
)
return False
def _set_cooldown_deployments(
litellm_router_instance: LitellmRouter,
original_exception: Any,
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
) -> bool:
"""
Add a model to the list of models being cooled down for that minute, if it exceeds the allowed fails / minute
or
the exception is not one that should be immediately retried (e.g. 401)
Returns:
- True if the deployment should be put in cooldown
- False if the deployment should not be put in cooldown
"""
verbose_router_logger.debug("checks 'should_run_cooldown_logic'")
if (
_should_run_cooldown_logic(
litellm_router_instance, deployment, exception_status, original_exception
)
is False
or deployment is None
):
verbose_router_logger.debug("should_run_cooldown_logic returned False")
return False
exception_status_int = cast_exception_status_to_int(exception_status)
verbose_router_logger.debug(f"Attempting to add {deployment} to cooldown list")
cooldown_time = litellm_router_instance.cooldown_time or 1
if time_to_cooldown is not None:
cooldown_time = time_to_cooldown
if _should_cooldown_deployment(
litellm_router_instance, deployment, exception_status, original_exception
):
litellm_router_instance.cooldown_cache.add_deployment_to_cooldown(
model_id=deployment,
original_exception=original_exception,
exception_status=exception_status_int,
cooldown_time=cooldown_time,
)
# Trigger cooldown callback handler
asyncio.create_task(
router_cooldown_event_callback(
litellm_router_instance=litellm_router_instance,
deployment_id=deployment,
exception_status=exception_status,
cooldown_time=cooldown_time,
)
)
return True
return False
async def _async_get_cooldown_deployments(
litellm_router_instance: LitellmRouter,
parent_otel_span: Optional[Span],
) -> List[str]:
"""
Async implementation of '_get_cooldown_deployments'
"""
model_ids = litellm_router_instance.get_model_ids()
cooldown_models = (
await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids,
parent_otel_span=parent_otel_span,
)
)
cached_value_deployment_ids = []
if (
cooldown_models is not None
and isinstance(cooldown_models, list)
and len(cooldown_models) > 0
and isinstance(cooldown_models[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cached_value_deployment_ids
async def _async_get_cooldown_deployments_with_debug_info(
litellm_router_instance: LitellmRouter,
parent_otel_span: Optional[Span],
) -> List[tuple]:
"""
Async implementation of '_get_cooldown_deployments'
"""
model_ids = litellm_router_instance.get_model_ids()
cooldown_models = (
await litellm_router_instance.cooldown_cache.async_get_active_cooldowns(
model_ids=model_ids, parent_otel_span=parent_otel_span
)
)
verbose_router_logger.debug(f"retrieve cooldown models: {cooldown_models}")
return cooldown_models
def _get_cooldown_deployments(
litellm_router_instance: LitellmRouter, parent_otel_span: Optional[Span]
) -> List[str]:
"""
Get the list of models being cooled down for this minute
"""
# get the current cooldown list for that minute
# ----------------------
# Return cooldown models
# ----------------------
model_ids = litellm_router_instance.get_model_ids()
cooldown_models = litellm_router_instance.cooldown_cache.get_active_cooldowns(
model_ids=model_ids, parent_otel_span=parent_otel_span
)
cached_value_deployment_ids = []
if (
cooldown_models is not None
and isinstance(cooldown_models, list)
and len(cooldown_models) > 0
and isinstance(cooldown_models[0], tuple)
):
cached_value_deployment_ids = [cv[0] for cv in cooldown_models]
return cached_value_deployment_ids
def should_cooldown_based_on_allowed_fails_policy(
litellm_router_instance: LitellmRouter,
deployment: str,
original_exception: Any,
) -> bool:
"""
Check if fails are within the allowed limit and update the number of fails.
Returns:
- True if fails exceed the allowed limit (should cooldown)
- False if fails are within the allowed limit (should not cooldown)
"""
allowed_fails = (
litellm_router_instance.get_allowed_fails_from_policy(
exception=original_exception,
)
or litellm_router_instance.allowed_fails
)
cooldown_time = (
litellm_router_instance.cooldown_time or DEFAULT_COOLDOWN_TIME_SECONDS
)
current_fails = litellm_router_instance.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1
if updated_fails > allowed_fails:
return True
else:
litellm_router_instance.failed_calls.set_cache(
key=deployment, value=updated_fails, ttl=cooldown_time
)
return False
def _is_allowed_fails_set_on_router(
litellm_router_instance: LitellmRouter,
) -> bool:
"""
Check if Router.allowed_fails is set or is Non-default Value
Returns:
- True if Router.allowed_fails is set or is Non-default Value
- False if Router.allowed_fails is None or is Default Value
"""
if litellm_router_instance.allowed_fails is None:
return False
if litellm_router_instance.allowed_fails != litellm.allowed_fails:
return True
return False
def cast_exception_status_to_int(exception_status: Union[str, int]) -> int:
if isinstance(exception_status, str):
try:
exception_status = int(exception_status)
except Exception:
verbose_router_logger.debug(
f"Unable to cast exception status to int {exception_status}. Defaulting to status=500."
)
exception_status = 500
return exception_status
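The status-based gate above can be exercised in isolation; a sketch (None stands in for the router instance, which _is_cooldown_required does not currently read):
assert cast_exception_status_to_int("429") == 429
assert cast_exception_status_to_int("not-a-status") == 500  # unparseable statuses default to 500

# 429 / 401 / 404 / 408 and all non-4xx errors cool down; other 4xx errors do not
assert _is_cooldown_required(litellm_router_instance=None, model_id="abc", exception_status=429) is True
assert _is_cooldown_required(litellm_router_instance=None, model_id="abc", exception_status=400) is False
assert _is_cooldown_required(litellm_router_instance=None, model_id="abc", exception_status=500) is True
# litellm APIConnectionError strings never trigger a cooldown
assert (
    _is_cooldown_required(
        litellm_router_instance=None,
        model_id="abc",
        exception_status=500,
        exception_str="APIConnectionError: ...",
    )
    is False
)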

View File

@@ -0,0 +1,303 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import litellm
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.router_utils.add_retry_fallback_headers import (
add_fallback_headers_to_response,
)
from litellm.types.router import LiteLLMParamsTypedDict
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def _check_stripped_model_group(model_group: str, fallback_key: str) -> bool:
"""
Handles wildcard routing scenario
where fallbacks set like:
[{"gpt-3.5-turbo": ["claude-3-haiku"]}]
but model_group is like:
"openai/gpt-3.5-turbo"
Returns:
- True if the stripped model group == fallback_key
"""
for provider in litellm.provider_list:
if isinstance(provider, Enum):
_provider = provider.value
else:
_provider = provider
if model_group.startswith(f"{_provider}/"):
stripped_model_group = model_group.replace(f"{_provider}/", "")
if stripped_model_group == fallback_key:
return True
return False
def get_fallback_model_group(
fallbacks: List[Any], model_group: str
) -> Tuple[Optional[List[str]], Optional[int]]:
"""
Returns:
- fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
- generic_fallback_idx: int of the index of the generic fallback in the fallbacks list.
Checks:
- exact match
- stripped model group match
- generic fallback
"""
generic_fallback_idx: Optional[int] = None
stripped_model_fallback: Optional[List[str]] = None
fallback_model_group: Optional[List[str]] = None
## check for specific model group-specific fallbacks
for idx, item in enumerate(fallbacks):
if isinstance(item, dict):
if list(item.keys())[0] == model_group: # check exact match
fallback_model_group = item[model_group]
break
elif _check_stripped_model_group(
model_group=model_group, fallback_key=list(item.keys())[0]
): # check generic fallback
stripped_model_fallback = item[list(item.keys())[0]]
elif list(item.keys())[0] == "*": # check generic fallback
generic_fallback_idx = idx
elif isinstance(item, str):
fallback_model_group = [fallbacks.pop(idx)] # returns single-item list
## if none, check for generic fallback
if fallback_model_group is None:
if stripped_model_fallback is not None:
fallback_model_group = stripped_model_fallback
elif generic_fallback_idx is not None:
fallback_model_group = fallbacks[generic_fallback_idx]["*"]
return fallback_model_group, generic_fallback_idx
async def run_async_fallback(
*args: Tuple[Any],
litellm_router: LitellmRouter,
fallback_model_group: List[str],
original_model_group: str,
original_exception: Exception,
max_fallbacks: int,
fallback_depth: int,
**kwargs,
) -> Any:
"""
Loops through all the fallback model groups and calls kwargs["original_function"] with the arguments and keyword arguments provided.
If the call is successful, it logs the success and returns the response.
If the call fails, it logs the failure and continues to the next fallback model group.
If all fallback model groups fail, it raises the most recent exception.
Args:
litellm_router: The litellm router instance.
*args: Positional arguments.
fallback_model_group: List[str] of fallback model groups. example: ["gpt-4", "gpt-3.5-turbo"]
original_model_group: The original model group. example: "gpt-3.5-turbo"
original_exception: The original exception.
**kwargs: Keyword arguments.
Returns:
The response from the successful fallback model group.
Raises:
The most recent exception if all fallback model groups fail.
"""
### BASE CASE ### MAX FALLBACK DEPTH REACHED
if fallback_depth >= max_fallbacks:
raise original_exception
error_from_fallbacks = original_exception
for mg in fallback_model_group:
if mg == original_model_group:
continue
try:
# LOGGING
kwargs = litellm_router.log_retry(kwargs=kwargs, e=original_exception)
verbose_router_logger.info(f"Falling back to model_group = {mg}")
if isinstance(mg, str):
kwargs["model"] = mg
elif isinstance(mg, dict):
kwargs.update(mg)
kwargs.setdefault("metadata", {}).update(
{"model_group": kwargs.get("model", None)}
) # update model_group used, if fallbacks are done
fallback_depth = fallback_depth + 1
kwargs["fallback_depth"] = fallback_depth
kwargs["max_fallbacks"] = max_fallbacks
response = await litellm_router.async_function_with_fallbacks(
*args, **kwargs
)
verbose_router_logger.info("Successful fallback b/w models.")
response = add_fallback_headers_to_response(
response=response,
attempted_fallbacks=fallback_depth,
)
# callback for log_success_fallback_event()
await log_success_fallback_event(
original_model_group=original_model_group,
kwargs=kwargs,
original_exception=original_exception,
)
return response
except Exception as e:
error_from_fallbacks = e
await log_failure_fallback_event(
original_model_group=original_model_group,
kwargs=kwargs,
original_exception=original_exception,
)
raise error_from_fallbacks
async def log_success_fallback_event(
original_model_group: str, kwargs: dict, original_exception: Exception
):
"""
Log a successful fallback event to all registered callbacks.
This function iterates through all callbacks, initializing _known_custom_logger_compatible_callbacks if needed,
and calls the log_success_fallback_event method on CustomLogger instances.
Args:
original_model_group (str): The original model group before fallback.
kwargs (dict): kwargs for the request
Note:
Errors during logging are caught and reported but do not interrupt the process.
"""
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger) or (
_callback in litellm._known_custom_logger_compatible_callbacks
):
try:
_callback_custom_logger: Optional[CustomLogger] = None
if _callback in litellm._known_custom_logger_compatible_callbacks:
_callback_custom_logger = _init_custom_logger_compatible_class(
logging_integration=_callback, # type: ignore
llm_router=None,
internal_usage_cache=None,
)
elif isinstance(_callback, CustomLogger):
_callback_custom_logger = _callback
else:
verbose_router_logger.exception(
f"{_callback} logger not found / initialized properly"
)
continue
if _callback_custom_logger is None:
verbose_router_logger.exception(
f"{_callback} logger not found / initialized properly, callback is None"
)
continue
await _callback_custom_logger.log_success_fallback_event(
original_model_group=original_model_group,
kwargs=kwargs,
original_exception=original_exception,
)
except Exception as e:
verbose_router_logger.error(
f"Error in log_success_fallback_event: {str(e)}"
)
async def log_failure_fallback_event(
original_model_group: str, kwargs: dict, original_exception: Exception
):
"""
Log a failed fallback event to all registered callbacks.
This function iterates through all callbacks, initializing _known_custom_logger_compatible_callbacks if needed,
and calls the log_failure_fallback_event method on CustomLogger instances.
Args:
original_model_group (str): The original model group before fallback.
kwargs (dict): kwargs for the request
Note:
Errors during logging are caught and reported but do not interrupt the process.
"""
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger) or (
_callback in litellm._known_custom_logger_compatible_callbacks
):
try:
_callback_custom_logger: Optional[CustomLogger] = None
if _callback in litellm._known_custom_logger_compatible_callbacks:
_callback_custom_logger = _init_custom_logger_compatible_class(
logging_integration=_callback, # type: ignore
llm_router=None,
internal_usage_cache=None,
)
elif isinstance(_callback, CustomLogger):
_callback_custom_logger = _callback
else:
verbose_router_logger.exception(
f"{_callback} logger not found / initialized properly"
)
continue
if _callback_custom_logger is None:
verbose_router_logger.exception(
f"{_callback} logger not found / initialized properly"
)
continue
await _callback_custom_logger.log_failure_fallback_event(
original_model_group=original_model_group,
kwargs=kwargs,
original_exception=original_exception,
)
except Exception as e:
verbose_router_logger.error(
f"Error in log_failure_fallback_event: {str(e)}"
)
def _check_non_standard_fallback_format(fallbacks: Optional[List[Any]]) -> bool:
"""
Checks whether the fallbacks list is in a non-standard format.
Returns True if fallbacks is either:
- List[str]: e.g. ["claude-3-haiku", "openai/o-1"]
- List[Dict[<LiteLLMParamsTypedDict>, Any]]: e.g. [{"model": "claude-3-haiku", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}]
Returns False for the standard format, e.g. [{"gpt-3.5-turbo": ["claude-3-haiku"]}].
"""
if fallbacks is None or not isinstance(fallbacks, list) or len(fallbacks) == 0:
return False
if all(isinstance(item, str) for item in fallbacks):
return True
elif all(isinstance(item, dict) for item in fallbacks):
for key in LiteLLMParamsTypedDict.__annotations__.keys():
if key in fallbacks[0].keys():
return True
return False
def run_non_standard_fallback_format(
fallbacks: Union[List[str], List[Dict[str, Any]]], model_group: str
):
pass
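A data-only sketch of the fallback-group resolution and format check above (plain lists and dicts, no router needed):
fallbacks = [
    {"gpt-3.5-turbo": ["gpt-4"]},
    {"*": ["claude-3-haiku"]},
]

# exact model-group match wins
group, _ = get_fallback_model_group(fallbacks=list(fallbacks), model_group="gpt-3.5-turbo")
assert group == ["gpt-4"]

# anything else falls through to the generic "*" entry
group, generic_idx = get_fallback_model_group(fallbacks=list(fallbacks), model_group="some-other-model")
assert group == ["claude-3-haiku"] and generic_idx == 1

# list-of-strings fallbacks are the "non-standard" format; keyed dicts are standard
assert _check_non_standard_fallback_format(["claude-3-haiku", "openai/o-1"]) is True
assert _check_non_standard_fallback_format(fallbacks) is False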

View File

@@ -0,0 +1,71 @@
"""
Get num retries for an exception.
- Account for retry policy by exception type.
"""
from typing import Dict, Optional, Union
from litellm.exceptions import (
AuthenticationError,
BadRequestError,
ContentPolicyViolationError,
RateLimitError,
Timeout,
)
from litellm.types.router import RetryPolicy
def get_num_retries_from_retry_policy(
exception: Exception,
retry_policy: Optional[Union[RetryPolicy, dict]] = None,
model_group: Optional[str] = None,
model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
):
"""
BadRequestErrorRetries: Optional[int] = None
AuthenticationErrorRetries: Optional[int] = None
TimeoutErrorRetries: Optional[int] = None
RateLimitErrorRetries: Optional[int] = None
ContentPolicyViolationErrorRetries: Optional[int] = None
"""
# if we can find the exception type in the retry policy -> return the number of retries
if (
model_group_retry_policy is not None
and model_group is not None
and model_group in model_group_retry_policy
):
retry_policy = model_group_retry_policy.get(model_group, None) # type: ignore
if retry_policy is None:
return None
if isinstance(retry_policy, dict):
retry_policy = RetryPolicy(**retry_policy)
if (
isinstance(exception, BadRequestError)
and retry_policy.BadRequestErrorRetries is not None
):
return retry_policy.BadRequestErrorRetries
if (
isinstance(exception, AuthenticationError)
and retry_policy.AuthenticationErrorRetries is not None
):
return retry_policy.AuthenticationErrorRetries
if isinstance(exception, Timeout) and retry_policy.TimeoutErrorRetries is not None:
return retry_policy.TimeoutErrorRetries
if (
isinstance(exception, RateLimitError)
and retry_policy.RateLimitErrorRetries is not None
):
return retry_policy.RateLimitErrorRetries
if (
isinstance(exception, ContentPolicyViolationError)
and retry_policy.ContentPolicyViolationErrorRetries is not None
):
return retry_policy.ContentPolicyViolationErrorRetries
def reset_retry_policy() -> RetryPolicy:
return RetryPolicy()
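A hedged usage sketch; the RateLimitError constructor arguments (message / llm_provider / model) are assumed here to match litellm's exception classes:
policy = RetryPolicy(RateLimitErrorRetries=3, TimeoutErrorRetries=1)

rate_limit_exc = RateLimitError(
    message="rate limited",  # constructor args assumed, see note above
    llm_provider="openai",
    model="gpt-4o",
)
assert get_num_retries_from_retry_policy(exception=rate_limit_exc, retry_policy=policy) == 3

# dict policies are accepted and coerced into a RetryPolicy
assert (
    get_num_retries_from_retry_policy(
        exception=rate_limit_exc, retry_policy={"RateLimitErrorRetries": 5}
    )
    == 5
)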

View File

@@ -0,0 +1,90 @@
from typing import TYPE_CHECKING, Any, Optional, Union
from litellm._logging import verbose_router_logger
from litellm.constants import MAX_EXCEPTION_MESSAGE_LENGTH
from litellm.router_utils.cooldown_handlers import (
_async_get_cooldown_deployments_with_debug_info,
)
from litellm.types.integrations.slack_alerting import AlertType
from litellm.types.router import RouterRateLimitError
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router as _Router
LitellmRouter = _Router
Span = Union[_Span, Any]
else:
LitellmRouter = Any
Span = Any
async def send_llm_exception_alert(
litellm_router_instance: LitellmRouter,
request_kwargs: dict,
error_traceback_str: str,
original_exception,
):
"""
Sends a Slack / MS Teams alert for the LLM API call failure. Only runs if router.slack_alerting_logger is set.
Parameters:
litellm_router_instance (_Router): The LitellmRouter instance.
original_exception (Any): The original exception that occurred.
Returns:
None
"""
if litellm_router_instance is None:
return
if not hasattr(litellm_router_instance, "slack_alerting_logger"):
return
if litellm_router_instance.slack_alerting_logger is None:
return
if "proxy_server_request" in request_kwargs:
# Do not send any alert if it's a litellm proxy server request
# the proxy is already instrumented to send LLM API call failures
return
litellm_debug_info = getattr(original_exception, "litellm_debug_info", None)
exception_str = str(original_exception)
if litellm_debug_info is not None:
exception_str += litellm_debug_info
exception_str += f"\n\n{error_traceback_str[:MAX_EXCEPTION_MESSAGE_LENGTH]}"
await litellm_router_instance.slack_alerting_logger.send_alert(
message=f"LLM API call failed: `{exception_str}`",
level="High",
alert_type=AlertType.llm_exceptions,
alerting_metadata={},
)
async def async_raise_no_deployment_exception(
litellm_router_instance: LitellmRouter, model: str, parent_otel_span: Optional[Span]
):
"""
Builds and returns a RouterRateLimitError for the caller to raise when no deployment is available for the given model.
"""
verbose_router_logger.info(
f"get_available_deployment for model: {model}, No deployment available"
)
model_ids = litellm_router_instance.get_model_ids(model_name=model)
_cooldown_time = litellm_router_instance.cooldown_cache.get_min_cooldown(
model_ids=model_ids, parent_otel_span=parent_otel_span
)
_cooldown_list = await _async_get_cooldown_deployments_with_debug_info(
litellm_router_instance=litellm_router_instance,
parent_otel_span=parent_otel_span,
)
return RouterRateLimitError(
model=model,
cooldown_time=_cooldown_time,
enable_pre_call_checks=litellm_router_instance.enable_pre_call_checks,
cooldown_list=_cooldown_list,
)

View File

@@ -0,0 +1,264 @@
"""
Class to handle llm wildcard routing and regex pattern matching
"""
import copy
import re
from re import Match
from typing import Dict, List, Optional, Tuple
from litellm import get_llm_provider
from litellm._logging import verbose_router_logger
class PatternUtils:
@staticmethod
def calculate_pattern_specificity(pattern: str) -> Tuple[int, int]:
"""
Calculate pattern specificity based on length and complexity.
Args:
pattern: Regex pattern to analyze
Returns:
Tuple of (length, complexity) for sorting
"""
complexity_chars = ["*", "+", "?", "\\", "^", "$", "|", "(", ")"]
ret_val = (
len(pattern), # Longer patterns more specific
sum(
pattern.count(char) for char in complexity_chars
), # More regex complexity
)
return ret_val
@staticmethod
def sorted_patterns(
patterns: Dict[str, List[Dict]]
) -> List[Tuple[str, List[Dict]]]:
"""
Return patterns sorted by specificity (most specific first).
Returns:
Sorted list of pattern-deployment tuples
"""
return sorted(
patterns.items(),
key=lambda x: PatternUtils.calculate_pattern_specificity(x[0]),
reverse=True,
)
class PatternMatchRouter:
"""
Class to handle llm wildcard routing and regex pattern matching
doc: https://docs.litellm.ai/docs/proxy/configs#provider-specific-wildcard-routing
This class will store a mapping for regex pattern: List[Deployments]
"""
def __init__(self):
self.patterns: Dict[str, List] = {}
def add_pattern(self, pattern: str, llm_deployment: Dict):
"""
Add a regex pattern and the corresponding llm deployments to the patterns
Args:
pattern: str
llm_deployment: Dict
"""
# Convert the pattern to a regex
regex = self._pattern_to_regex(pattern)
if regex not in self.patterns:
self.patterns[regex] = []
self.patterns[regex].append(llm_deployment)
def _pattern_to_regex(self, pattern: str) -> str:
"""
Convert a wildcard pattern to a regex pattern
example:
pattern: openai/*
regex: openai/.*
pattern: openai/fo::*::static::*
regex: openai/fo::.*::static::.*
Args:
pattern: str
Returns:
str: regex pattern
"""
# # Replace '*' with '.*' for regex matching
# regex = pattern.replace("*", ".*")
# # Escape other special characters
# regex = re.escape(regex).replace(r"\.\*", ".*")
# return f"^{regex}$"
return re.escape(pattern).replace(r"\*", "(.*)")
def _return_pattern_matched_deployments(
self, matched_pattern: Match, deployments: List[Dict]
) -> List[Dict]:
new_deployments = []
for deployment in deployments:
new_deployment = copy.deepcopy(deployment)
new_deployment["litellm_params"][
"model"
] = PatternMatchRouter.set_deployment_model_name(
matched_pattern=matched_pattern,
litellm_deployment_litellm_model=deployment["litellm_params"]["model"],
)
new_deployments.append(new_deployment)
return new_deployments
def route(
self, request: Optional[str], filtered_model_names: Optional[List[str]] = None
) -> Optional[List[Dict]]:
"""
Route a requested model to the corresponding llm deployments based on the regex pattern
loop through all the patterns and find the matching pattern
if a pattern is found, return the corresponding llm deployments
if no pattern is found, return None
Args:
request: str - the received model name from the user (can be a wildcard route). If None, no deployments will be returned.
filtered_model_names: Optional[List[str]] - if provided, only return deployments that match the filtered_model_names
Returns:
Optional[List[Deployment]]: llm deployments
"""
try:
if request is None:
return None
sorted_patterns = PatternUtils.sorted_patterns(self.patterns)
regex_filtered_model_names = (
[self._pattern_to_regex(m) for m in filtered_model_names]
if filtered_model_names is not None
else []
)
for pattern, llm_deployments in sorted_patterns:
if (
filtered_model_names is not None
and pattern not in regex_filtered_model_names
):
continue
pattern_match = re.match(pattern, request)
if pattern_match:
return self._return_pattern_matched_deployments(
matched_pattern=pattern_match, deployments=llm_deployments
)
except Exception as e:
verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
return None # No matching pattern found
@staticmethod
def set_deployment_model_name(
matched_pattern: Match,
litellm_deployment_litellm_model: str,
) -> str:
"""
Set the model name for the matched pattern llm deployment
E.g.:
Case 1:
model_name: llmengine/* (can be any regex pattern or wildcard pattern)
litellm_params:
model: openai/*
if model_name = "llmengine/foo" -> model = "openai/foo"
Case 2:
model_name: llmengine/fo::*::static::*
litellm_params:
model: openai/fo::*::static::*
if model_name = "llmengine/foo::bar::static::baz" -> model = "openai/foo::bar::static::baz"
Case 3:
model_name: *meta.llama3*
litellm_params:
model: bedrock/meta.llama3*
if model_name = "hello-world-meta.llama3-70b" -> model = "bedrock/meta.llama3-70b"
"""
## BASE CASE: if the deployment model name does not contain a wildcard, return the deployment model name
if "*" not in litellm_deployment_litellm_model:
return litellm_deployment_litellm_model
wildcard_count = litellm_deployment_litellm_model.count("*")
# Extract all dynamic segments from the request
dynamic_segments = matched_pattern.groups()
if len(dynamic_segments) > wildcard_count:
return (
matched_pattern.string
) # default to the user input, if unable to map based on wildcards.
# Replace the corresponding wildcards in the litellm model pattern with extracted segments
for segment in dynamic_segments:
litellm_deployment_litellm_model = litellm_deployment_litellm_model.replace(
"*", segment, 1
)
return litellm_deployment_litellm_model
def get_pattern(
self, model: str, custom_llm_provider: Optional[str] = None
) -> Optional[List[Dict]]:
"""
Check if a pattern exists for the given model and custom llm provider
Args:
model: str
custom_llm_provider: Optional[str]
Returns:
Optional[List[Dict]]: matching llm deployments if a pattern exists, None otherwise
"""
if custom_llm_provider is None:
try:
(
_,
custom_llm_provider,
_,
_,
) = get_llm_provider(model=model)
except Exception:
# get_llm_provider raises exception when provider is unknown
pass
return self.route(model) or self.route(f"{custom_llm_provider}/{model}")
def get_deployments_by_pattern(
self, model: str, custom_llm_provider: Optional[str] = None
) -> List[Dict]:
"""
Get the deployments by pattern
Args:
model: str
custom_llm_provider: Optional[str]
Returns:
List[Dict]: llm deployments matching the pattern
"""
pattern_match = self.get_pattern(model, custom_llm_provider)
if pattern_match:
return pattern_match
return []
# Example usage:
# router = PatternMatchRouter()
# router.add_pattern('openai/*', deployment_a)
# router.add_pattern('openai/*', deployment_b)
# router.add_pattern('openai/fo::*::static::*', deployment_c)
# print(router.route('openai/gpt-4')) # Output: [deployment_a, deployment_b]
# print(router.route('openai/fo::hi::static::hi')) # Output: [deployment_c]
# print(router.route('something/else')) # Output: None
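A runnable sketch of the wildcard substitution described in set_deployment_model_name (Case 1), using a minimal deployment dict with placeholder values:
pattern_router = PatternMatchRouter()
pattern_router.add_pattern(
    "llmengine/*",
    {
        "model_name": "llmengine/*",
        "litellm_params": {"model": "openai/*"},
        "model_info": {"id": "deployment-1"},
    },
)
matched = pattern_router.route("llmengine/gpt-4o")
assert matched is not None
assert matched[0]["litellm_params"]["model"] == "openai/gpt-4o"
assert pattern_router.route("anthropic/claude-3-haiku") is None  # no pattern registered for this route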

View File

@@ -0,0 +1,99 @@
"""
Check if prompt caching is valid for a given deployment
Route to previously cached model id, if valid
"""
from typing import List, Optional, cast
from litellm import verbose_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import CallTypes, StandardLoggingPayload
from litellm.utils import is_prompt_caching_valid_prompt
from ..prompt_caching_cache import PromptCachingCache
class PromptCachingDeploymentCheck(CustomLogger):
def __init__(self, cache: DualCache):
self.cache = cache
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None,
) -> List[dict]:
if messages is not None and is_prompt_caching_valid_prompt(
messages=messages,
model=model,
): # prompt > 1024 tokens
prompt_cache = PromptCachingCache(
cache=self.cache,
)
model_id_dict = await prompt_cache.async_get_model_id(
messages=cast(List[AllMessageValues], messages),
tools=None,
)
if model_id_dict is not None:
model_id = model_id_dict["model_id"]
for deployment in healthy_deployments:
if deployment["model_info"]["id"] == model_id:
return [deployment]
return healthy_deployments
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_object is None:
return
call_type = standard_logging_object["call_type"]
if (
call_type != CallTypes.completion.value
and call_type != CallTypes.acompletion.value
): # only use prompt caching for completion calls
verbose_logger.debug(
"litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, CALL TYPE IS NOT COMPLETION"
)
return
model = standard_logging_object["model"]
messages = standard_logging_object["messages"]
model_id = standard_logging_object["model_id"]
if messages is None or not isinstance(messages, list):
verbose_logger.debug(
"litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MESSAGES IS NOT A LIST"
)
return
if model_id is None:
verbose_logger.debug(
"litellm.router_utils.pre_call_checks.prompt_caching_deployment_check: skipping adding model id to prompt caching cache, MODEL ID IS NONE"
)
return
## PROMPT CACHING - cache model id, if prompt caching valid prompt + provider
if is_prompt_caching_valid_prompt(
model=model,
messages=cast(List[AllMessageValues], messages),
):
cache = PromptCachingCache(
cache=self.cache,
)
await cache.async_add_model_id(
model_id=model_id,
messages=messages,
tools=None, # [TODO]: add tools once standard_logging_object supports it
)
return

View File

@@ -0,0 +1,45 @@
"""
For Responses API, we need routing affinity when a user sends a previous_response_id.
e.g. if proxy admins are load balancing across N gpt-4.1-turbo deployments and a user sends a previous_response_id,
we want to route to the same gpt-4.1-turbo deployment that produced that response.
The router has no such affinity by default, so when previous_response_id is provided this check
filters the healthy deployments down to the deployment that returned the previous response.
"""
from typing import List, Optional
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import AllMessageValues
class ResponsesApiDeploymentCheck(CustomLogger):
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None,
) -> List[dict]:
request_kwargs = request_kwargs or {}
previous_response_id = request_kwargs.get("previous_response_id", None)
if previous_response_id is None:
return healthy_deployments
decoded_response = ResponsesAPIRequestUtils._decode_responses_api_response_id(
response_id=previous_response_id,
)
model_id = decoded_response.get("model_id")
if model_id is None:
return healthy_deployments
for deployment in healthy_deployments:
if deployment["model_info"]["id"] == model_id:
return [deployment]
return healthy_deployments

View File

@@ -0,0 +1,171 @@
"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""
import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, TypedDict, Union
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router
litellm_router = Router
Span = Union[_Span, Any]
else:
Span = Any
litellm_router = Any
class PromptCachingCacheValue(TypedDict):
model_id: str
class PromptCachingCache:
def __init__(self, cache: DualCache):
self.cache = cache
self.in_memory_cache = InMemoryCache()
@staticmethod
def serialize_object(obj: Any) -> Any:
"""Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
if hasattr(obj, "dict"):
# If the object is a Pydantic model, use its `dict()` method
return obj.dict()
elif isinstance(obj, dict):
# If the object is a dictionary, serialize it with sorted keys
return json.dumps(
obj, sort_keys=True, separators=(",", ":")
) # Standardize serialization
elif isinstance(obj, list):
# Serialize lists by ensuring each element is handled properly
return [PromptCachingCache.serialize_object(item) for item in obj]
elif isinstance(obj, (int, float, bool)):
return obj # Keep primitive types as-is
return str(obj)
@staticmethod
def get_prompt_caching_cache_key(
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[str]:
if messages is None and tools is None:
return None
# Use serialize_object for consistent and stable serialization
data_to_hash = {}
if messages is not None:
serialized_messages = PromptCachingCache.serialize_object(messages)
data_to_hash["messages"] = serialized_messages
if tools is not None:
serialized_tools = PromptCachingCache.serialize_object(tools)
data_to_hash["tools"] = serialized_tools
# Combine serialized data into a single string
data_to_hash_str = json.dumps(
data_to_hash,
sort_keys=True,
separators=(",", ":"),
)
# Create a hash of the serialized data for a stable cache key
hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
return f"deployment:{hashed_data}:prompt_caching"
def add_model_id(
self,
model_id: str,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> None:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
self.cache.set_cache(
cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
)
return None
async def async_add_model_id(
self,
model_id: str,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> None:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
await self.cache.async_set_cache(
cache_key,
PromptCachingCacheValue(model_id=model_id),
ttl=300, # store for 5 minutes
)
return None
async def async_get_model_id(
self,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[PromptCachingCacheValue]:
"""
if messages is not None:
- check full messages
- check messages[:-1]
- check messages[:-2]
- check messages[:-3]
use self.cache.async_batch_get_cache(keys=potential_cache_keys)
"""
if messages is None and tools is None:
return None
# Generate potential cache keys by slicing messages
potential_cache_keys = []
if messages is not None:
full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
messages, tools
)
potential_cache_keys.append(full_cache_key)
# Check progressively shorter message slices
for i in range(1, min(4, len(messages))):
partial_messages = messages[:-i]
partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
partial_messages, tools
)
potential_cache_keys.append(partial_cache_key)
# Perform batch cache lookup
cache_results = await self.cache.async_batch_get_cache(
keys=potential_cache_keys
)
if cache_results is None:
return None
# Return the first non-None cache result
for result in cache_results:
if result is not None:
return result
return None
def get_model_id(
self,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[PromptCachingCacheValue]:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
return self.cache.get_cache(cache_key)
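The cache-key derivation above is pure, so it can be sketched without a backing cache:
messages = [{"role": "user", "content": "hello world " * 300}]
key_a = PromptCachingCache.get_prompt_caching_cache_key(messages=messages, tools=None)
key_b = PromptCachingCache.get_prompt_caching_cache_key(messages=list(messages), tools=None)
# identical message/tool payloads always map to the same key
assert key_a == key_b
assert key_a is not None and key_a.startswith("deployment:") and key_a.endswith(":prompt_caching")
assert PromptCachingCache.get_prompt_caching_cache_key(messages=None, tools=None) is None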

View File

@@ -0,0 +1,90 @@
"""
Helper functions to get/set num success and num failures per deployment
set_deployment_failures_for_current_minute
set_deployment_successes_for_current_minute
get_deployment_failures_for_current_minute
get_deployment_successes_for_current_minute
"""
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def increment_deployment_successes_for_current_minute(
litellm_router_instance: LitellmRouter,
deployment_id: str,
) -> str:
"""
In-Memory: Increments the number of successes for the current minute for a deployment_id
"""
key = f"{deployment_id}:successes"
litellm_router_instance.cache.increment_cache(
local_only=True,
key=key,
value=1,
ttl=60,
)
return key
def increment_deployment_failures_for_current_minute(
litellm_router_instance: LitellmRouter,
deployment_id: str,
):
"""
In-Memory: Increments the number of failures for the current minute for a deployment_id
"""
key = f"{deployment_id}:fails"
litellm_router_instance.cache.increment_cache(
local_only=True,
key=key,
value=1,
ttl=60,
)
def get_deployment_successes_for_current_minute(
litellm_router_instance: LitellmRouter,
deployment_id: str,
) -> int:
"""
Returns the number of successes for the current minute for a deployment_id
Returns 0 if no value found
"""
key = f"{deployment_id}:successes"
return (
litellm_router_instance.cache.get_cache(
local_only=True,
key=key,
)
or 0
)
def get_deployment_failures_for_current_minute(
litellm_router_instance: LitellmRouter,
deployment_id: str,
) -> int:
"""
Returns the number of fails for the current minute for a deployment_id
Returns 0 if no value found
"""
key = f"{deployment_id}:fails"
return (
litellm_router_instance.cache.get_cache(
local_only=True,
key=key,
)
or 0
)
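A hedged end-to-end sketch, assuming a litellm Router built from a one-entry model_list (the api_key is a placeholder):
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"},
        }
    ]
)
deployment_id = router.get_model_ids()[0]

increment_deployment_successes_for_current_minute(router, deployment_id)
increment_deployment_successes_for_current_minute(router, deployment_id)
increment_deployment_failures_for_current_minute(router, deployment_id)

assert get_deployment_successes_for_current_minute(router, deployment_id) == 2
assert get_deployment_failures_for_current_minute(router, deployment_id) == 1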