structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,190 @@
"""
Base class shared across routing strategies to abstract common functions like batched Redis increments
"""
import asyncio
import threading
from abc import ABC
from typing import List, Optional, Set, Union
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
from litellm.constants import DEFAULT_REDIS_SYNC_INTERVAL
class BaseRoutingStrategy(ABC):
def __init__(
self,
dual_cache: DualCache,
should_batch_redis_writes: bool,
default_sync_interval: Optional[Union[int, float]],
):
self.dual_cache = dual_cache
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
if should_batch_redis_writes:
try:
# Try to get existing event loop
loop = asyncio.get_event_loop()
if loop.is_running():
# If loop exists and is running, create task in existing loop
loop.create_task(
self.periodic_sync_in_memory_spend_with_redis(
default_sync_interval=default_sync_interval
)
)
else:
self._create_sync_thread(default_sync_interval)
except RuntimeError: # No event loop in current thread
self._create_sync_thread(default_sync_interval)
self.in_memory_keys_to_update: set[
str
] = set() # Set with max size of 1000 keys
async def _increment_value_in_current_window(
self, key: str, value: Union[int, float], ttl: int
):
"""
Increment spend within existing budget window
Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider)
- Increments the spend in the in-memory cache (so spend is instantly updated in memory)
- Queues the increment operation onto the Redis pipeline (batched to optimize performance; Redis keeps spend in sync across multi-instance LiteLLM deployments)
"""
result = await self.dual_cache.in_memory_cache.async_increment(
key=key,
value=value,
ttl=ttl,
)
increment_op = RedisPipelineIncrementOperation(
key=key,
increment_value=value,
ttl=ttl,
)
self.redis_increment_operation_queue.append(increment_op)
self.add_to_in_memory_keys_to_update(key=key)
return result
async def periodic_sync_in_memory_spend_with_redis(
self, default_sync_interval: Optional[Union[int, float]]
):
"""
Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds
Required for multi-instance environment usage of provider budgets
"""
default_sync_interval = default_sync_interval or DEFAULT_REDIS_SYNC_INTERVAL
while True:
try:
await self._sync_in_memory_spend_with_redis()
await asyncio.sleep(
default_sync_interval
) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
except Exception as e:
verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
await asyncio.sleep(
default_sync_interval
) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying
async def _push_in_memory_increments_to_redis(self):
"""
How this works:
- async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
- This function pushes all increments to Redis in a batched pipeline to optimize performance
Only runs if Redis is initialized
"""
try:
if not self.dual_cache.redis_cache:
return # Redis is not initialized
verbose_router_logger.debug(
"Pushing Redis Increment Pipeline for queue: %s",
self.redis_increment_operation_queue,
)
if len(self.redis_increment_operation_queue) > 0:
asyncio.create_task(
self.dual_cache.redis_cache.async_increment_pipeline(
increment_list=self.redis_increment_operation_queue,
)
)
self.redis_increment_operation_queue = []
except Exception as e:
verbose_router_logger.error(
f"Error syncing in-memory cache with Redis: {str(e)}"
)
self.redis_increment_operation_queue = []
def add_to_in_memory_keys_to_update(self, key: str):
self.in_memory_keys_to_update.add(key)
def get_in_memory_keys_to_update(self) -> Set[str]:
return self.in_memory_keys_to_update
def reset_in_memory_keys_to_update(self):
self.in_memory_keys_to_update = set()
async def _sync_in_memory_spend_with_redis(self):
"""
Ensures in-memory cache is updated with latest Redis values for all provider spends.
Why do we need this?
- Optimization to hit sub-100ms latency: performance was impacted when Redis was read from/written to on every request
- To support provider budgets in a multi-instance environment, Redis is used to sync spend across all instances
What this does:
1. Push all provider spend increments to Redis
2. Fetch all current provider spend from Redis to update in-memory cache
"""
try:
# No need to sync if Redis cache is not initialized
if self.dual_cache.redis_cache is None:
return
# 1. Push all provider spend increments to Redis
await self._push_in_memory_increments_to_redis()
# 2. Fetch all current provider spend from Redis to update in-memory cache
cache_keys = self.get_in_memory_keys_to_update()
cache_keys_list = list(cache_keys)
# Batch fetch current spend values from Redis
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
key_list=cache_keys_list
)
# Update in-memory cache with Redis values
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
for key, value in redis_values.items():
if value is not None:
await self.dual_cache.in_memory_cache.async_set_cache(
key=key, value=float(value)
)
verbose_router_logger.debug(
f"Updated in-memory cache for {key}: {value}"
)
self.reset_in_memory_keys_to_update()
except Exception as e:
verbose_router_logger.exception(
f"Error syncing in-memory cache with Redis: {str(e)}"
)
def _create_sync_thread(self, default_sync_interval):
"""Helper method to create a new thread for periodic sync"""
thread = threading.Thread(
target=asyncio.run,
args=(
self.periodic_sync_in_memory_spend_with_redis(
default_sync_interval=default_sync_interval
),
),
daemon=True,
)
thread.start()
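
A minimal usage sketch of the batched write path above, under the assumption that `BaseRoutingStrategy` is importable from this module; the subclass name, spend key format, and cost values below are hypothetical:

```
import asyncio
from litellm.caching.caching import DualCache


class MyProviderStrategy(BaseRoutingStrategy):
    """Hypothetical subclass that tracks per-provider spend."""

    async def record_spend(self, provider: str, cost: float):
        # updates the in-memory cache immediately and queues a Redis pipeline increment
        await self._increment_value_in_current_window(
            key=f"provider_spend:{provider}:1d",  # hypothetical key format
            value=cost,
            ttl=86400,
        )


async def main():
    strategy = MyProviderStrategy(
        dual_cache=DualCache(),          # in-memory only here, so the Redis flush is a no-op
        should_batch_redis_writes=True,  # schedules periodic_sync_in_memory_spend_with_redis
        default_sync_interval=1,
    )
    await strategy.record_spend("openai", 0.002)
    # the queued RedisPipelineIncrementOperation is flushed on the next sync tick;
    # in a real router this background task lives for the lifetime of the process
    await asyncio.sleep(1.5)


asyncio.run(main())
```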


@@ -0,0 +1,821 @@
"""
Provider budget limiting
Use this if you want to set $ budget limits for each provider.
Note: This is a filter, like tag-routing. It accepts the healthy deployments and then filters out the ones that have exceeded their budget limit.
This means you can use it alongside weighted-pick, lowest-latency, simple-shuffle routing, etc.
Example:
```
openai:
  budget_limit: 0.000000000001
  time_period: 1d
anthropic:
  budget_limit: 100
  time_period: 7d
```
"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple, Union
import litellm
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
from litellm.integrations.custom_logger import CustomLogger, Span
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
from litellm.router_utils.cooldown_callbacks import (
_get_prometheus_logger_from_callbacks,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params, RouterErrors
from litellm.types.utils import BudgetConfig
from litellm.types.utils import BudgetConfig as GenericBudgetInfo
from litellm.types.utils import GenericBudgetConfigType, StandardLoggingPayload
DEFAULT_REDIS_SYNC_INTERVAL = 1
class RouterBudgetLimiting(CustomLogger):
def __init__(
self,
dual_cache: DualCache,
provider_budget_config: Optional[dict],
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
):
self.dual_cache = dual_cache
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
self.provider_budget_config: Optional[
GenericBudgetConfigType
] = provider_budget_config
self.deployment_budget_config: Optional[GenericBudgetConfigType] = None
self.tag_budget_config: Optional[GenericBudgetConfigType] = None
self._init_provider_budgets()
self._init_deployment_budgets(model_list=model_list)
self._init_tag_budgets()
# Add self to litellm callbacks if it's a list
if isinstance(litellm.callbacks, list):
litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None, # type: ignore
) -> List[dict]:
"""
Filter out deployments that have exceeded their budget limits (provider, deployment, or request-tag).
Example:
if deployment = openai/gpt-3.5-turbo
and openai spend > openai budget limit
then skip this deployment
"""
# If a single deployment is passed, convert it to a list
if isinstance(healthy_deployments, dict):
healthy_deployments = [healthy_deployments]
# Don't do any filtering if there are no healthy deployments
if len(healthy_deployments) == 0:
return healthy_deployments
potential_deployments: List[Dict] = []
(
cache_keys,
provider_configs,
deployment_configs,
) = await self._async_get_cache_keys_for_router_budget_limiting(
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
)
# Single cache read for all spend values
if len(cache_keys) > 0:
_current_spends = await self.dual_cache.async_batch_get_cache(
keys=cache_keys,
parent_otel_span=parent_otel_span,
)
current_spends: List = _current_spends or [0.0] * len(cache_keys)
# Map spends to their respective keys
spend_map: Dict[str, float] = {}
for idx, key in enumerate(cache_keys):
spend_map[key] = float(current_spends[idx] or 0.0)
(
potential_deployments,
deployment_above_budget_info,
) = self._filter_out_deployments_above_budget(
healthy_deployments=healthy_deployments,
provider_configs=provider_configs,
deployment_configs=deployment_configs,
spend_map=spend_map,
potential_deployments=potential_deployments,
request_tags=_get_tags_from_request_kwargs(
request_kwargs=request_kwargs
),
)
if len(potential_deployments) == 0:
raise ValueError(
f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
)
return potential_deployments
else:
return healthy_deployments
def _filter_out_deployments_above_budget(
self,
potential_deployments: List[Dict[str, Any]],
healthy_deployments: List[Dict[str, Any]],
provider_configs: Dict[str, GenericBudgetInfo],
deployment_configs: Dict[str, GenericBudgetInfo],
spend_map: Dict[str, float],
request_tags: List[str],
) -> Tuple[List[Dict[str, Any]], str]:
"""
Filter out deployments that have exceeded their budget limit.
The following budget checks are run here:
- Provider budget
- Deployment budget
- Request tags budget
Returns:
Tuple[List[Dict[str, Any]], str]:
- A tuple containing the filtered deployments
- A string containing debug information about deployments that exceeded their budget limit.
"""
# Filter deployments based on both provider and deployment budgets
deployment_above_budget_info: str = ""
for deployment in healthy_deployments:
is_within_budget = True
# Check provider budget
if self.provider_budget_config:
provider = self._get_llm_provider_for_deployment(deployment)
if provider in provider_configs:
config = provider_configs[provider]
if config.max_budget is None:
continue
current_spend = spend_map.get(
f"provider_spend:{provider}:{config.budget_duration}", 0.0
)
self._track_provider_remaining_budget_prometheus(
provider=provider,
spend=current_spend,
budget_limit=config.max_budget,
)
if config.max_budget and current_spend >= config.max_budget:
debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.max_budget}"
deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False
continue
# Check deployment budget
if self.deployment_budget_config and is_within_budget:
_model_name = deployment.get("model_name")
_litellm_params = deployment.get("litellm_params") or {}
_litellm_model_name = _litellm_params.get("model")
model_id = deployment.get("model_info", {}).get("id")
if model_id in deployment_configs:
config = deployment_configs[model_id]
current_spend = spend_map.get(
f"deployment_spend:{model_id}:{config.budget_duration}", 0.0
)
if config.max_budget and current_spend >= config.max_budget:
debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_duration}"
verbose_router_logger.debug(debug_msg)
deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False
continue
# Check tag budget
if self.tag_budget_config and is_within_budget:
for _tag in request_tags:
_tag_budget_config = self._get_budget_config_for_tag(_tag)
if _tag_budget_config:
_tag_spend = spend_map.get(
f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}",
0.0,
)
if (
_tag_budget_config.max_budget
and _tag_spend >= _tag_budget_config.max_budget
):
debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.max_budget}"
verbose_router_logger.debug(debug_msg)
deployment_above_budget_info += f"{debug_msg}\n"
is_within_budget = False
continue
if is_within_budget:
potential_deployments.append(deployment)
return potential_deployments, deployment_above_budget_info
async def _async_get_cache_keys_for_router_budget_limiting(
self,
healthy_deployments: List[Dict[str, Any]],
request_kwargs: Optional[Dict] = None,
) -> Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]:
"""
Returns the list of cache keys to fetch from the router cache for budget limiting, along with the provider and deployment budget configs
Returns:
Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]:
- List of cache keys to fetch from router cache for budget limiting
- Dict of provider budget configs `provider_configs`
- Dict of deployment budget configs `deployment_configs`
"""
cache_keys: List[str] = []
provider_configs: Dict[str, GenericBudgetInfo] = {}
deployment_configs: Dict[str, GenericBudgetInfo] = {}
for deployment in healthy_deployments:
# Check provider budgets
if self.provider_budget_config:
provider = self._get_llm_provider_for_deployment(deployment)
if provider is not None:
budget_config = self._get_budget_config_for_provider(provider)
if (
budget_config is not None
and budget_config.budget_duration is not None
):
provider_configs[provider] = budget_config
cache_keys.append(
f"provider_spend:{provider}:{budget_config.budget_duration}"
)
# Check deployment budgets
if self.deployment_budget_config:
model_id = deployment.get("model_info", {}).get("id")
if model_id is not None:
budget_config = self._get_budget_config_for_deployment(model_id)
if budget_config is not None:
deployment_configs[model_id] = budget_config
cache_keys.append(
f"deployment_spend:{model_id}:{budget_config.budget_duration}"
)
# Check tag budgets
if self.tag_budget_config:
request_tags = _get_tags_from_request_kwargs(
request_kwargs=request_kwargs
)
for _tag in request_tags:
_tag_budget_config = self._get_budget_config_for_tag(_tag)
if _tag_budget_config:
cache_keys.append(
f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
)
return cache_keys, provider_configs, deployment_configs
async def _get_or_set_budget_start_time(
self, start_time_key: str, current_time: float, ttl_seconds: int
) -> float:
"""
Checks if the key = `provider_budget_start_time:{provider}` exists in cache.
If it does, return the value.
If it does not, set the key to `current_time` and return the value.
"""
budget_start = await self.dual_cache.async_get_cache(start_time_key)
if budget_start is None:
await self.dual_cache.async_set_cache(
key=start_time_key, value=current_time, ttl=ttl_seconds
)
return current_time
return float(budget_start)
async def _handle_new_budget_window(
self,
spend_key: str,
start_time_key: str,
current_time: float,
response_cost: float,
ttl_seconds: int,
) -> float:
"""
Handle start of new budget window by resetting spend and start time
Enters this when:
- The budget does not exist in cache, so we need to set it
- The budget window has expired, so we need to reset everything
Does 2 things:
- stores key: `provider_spend:{provider}:1d`, value: response_cost
- stores key: `provider_budget_start_time:{provider}`, value: current_time.
This stores the start time of the new budget window
"""
await self.dual_cache.async_set_cache(
key=spend_key, value=response_cost, ttl=ttl_seconds
)
await self.dual_cache.async_set_cache(
key=start_time_key, value=current_time, ttl=ttl_seconds
)
return current_time
async def _increment_spend_in_current_window(
self, spend_key: str, response_cost: float, ttl: int
):
"""
Increment spend within existing budget window
Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider)
- Increments the spend in the in-memory cache (so spend is instantly updated in memory)
- Queues the increment operation onto the Redis pipeline (batched to optimize performance; Redis keeps spend in sync across multi-instance LiteLLM deployments)
"""
await self.dual_cache.in_memory_cache.async_increment(
key=spend_key,
value=response_cost,
ttl=ttl,
)
increment_op = RedisPipelineIncrementOperation(
key=spend_key,
increment_value=response_cost,
ttl=ttl,
)
self.redis_increment_operation_queue.append(increment_op)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""Original method now uses helper functions"""
verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event")
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_payload is None:
raise ValueError("standard_logging_payload is required")
response_cost: float = standard_logging_payload.get("response_cost", 0)
model_id: str = str(standard_logging_payload.get("model_id", ""))
custom_llm_provider: str = kwargs.get("litellm_params", {}).get(
"custom_llm_provider", None
)
if custom_llm_provider is None:
raise ValueError("custom_llm_provider is required")
budget_config = self._get_budget_config_for_provider(custom_llm_provider)
if budget_config:
# increment spend for provider
spend_key = (
f"provider_spend:{custom_llm_provider}:{budget_config.budget_duration}"
)
start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
await self._increment_spend_for_key(
budget_config=budget_config,
spend_key=spend_key,
start_time_key=start_time_key,
response_cost=response_cost,
)
deployment_budget_config = self._get_budget_config_for_deployment(model_id)
if deployment_budget_config:
# increment spend for specific deployment id
deployment_spend_key = f"deployment_spend:{model_id}:{deployment_budget_config.budget_duration}"
deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
await self._increment_spend_for_key(
budget_config=deployment_budget_config,
spend_key=deployment_spend_key,
start_time_key=deployment_start_time_key,
response_cost=response_cost,
)
request_tags = _get_tags_from_request_kwargs(kwargs)
if len(request_tags) > 0:
for _tag in request_tags:
_tag_budget_config = self._get_budget_config_for_tag(_tag)
if _tag_budget_config:
_tag_spend_key = (
f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
)
_tag_start_time_key = f"tag_budget_start_time:{_tag}"
await self._increment_spend_for_key(
budget_config=_tag_budget_config,
spend_key=_tag_spend_key,
start_time_key=_tag_start_time_key,
response_cost=response_cost,
)
async def _increment_spend_for_key(
self,
budget_config: GenericBudgetInfo,
spend_key: str,
start_time_key: str,
response_cost: float,
):
if budget_config.budget_duration is None:
return
current_time = datetime.now(timezone.utc).timestamp()
ttl_seconds = duration_in_seconds(budget_config.budget_duration)
budget_start = await self._get_or_set_budget_start_time(
start_time_key=start_time_key,
current_time=current_time,
ttl_seconds=ttl_seconds,
)
if budget_start is None:
# First spend for this provider
budget_start = await self._handle_new_budget_window(
spend_key=spend_key,
start_time_key=start_time_key,
current_time=current_time,
response_cost=response_cost,
ttl_seconds=ttl_seconds,
)
elif (current_time - budget_start) > ttl_seconds:
# Budget window expired - reset everything
verbose_router_logger.debug("Budget window expired - resetting everything")
budget_start = await self._handle_new_budget_window(
spend_key=spend_key,
start_time_key=start_time_key,
current_time=current_time,
response_cost=response_cost,
ttl_seconds=ttl_seconds,
)
else:
# Within existing window - increment spend
remaining_time = ttl_seconds - (current_time - budget_start)
ttl_for_increment = int(remaining_time)
await self._increment_spend_in_current_window(
spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment
)
verbose_router_logger.debug(
f"Incremented spend for {spend_key} by {response_cost}"
)
async def periodic_sync_in_memory_spend_with_redis(self):
"""
Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds
Required for multi-instance environment usage of provider budgets
"""
while True:
try:
await self._sync_in_memory_spend_with_redis()
await asyncio.sleep(
DEFAULT_REDIS_SYNC_INTERVAL
) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
except Exception as e:
verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
await asyncio.sleep(
DEFAULT_REDIS_SYNC_INTERVAL
) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying
async def _push_in_memory_increments_to_redis(self):
"""
How this works:
- async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
- This function pushes all increments to Redis in a batched pipeline to optimize performance
Only runs if Redis is initialized
"""
try:
if not self.dual_cache.redis_cache:
return # Redis is not initialized
verbose_router_logger.debug(
"Pushing Redis Increment Pipeline for queue: %s",
self.redis_increment_operation_queue,
)
if len(self.redis_increment_operation_queue) > 0:
asyncio.create_task(
self.dual_cache.redis_cache.async_increment_pipeline(
increment_list=self.redis_increment_operation_queue,
)
)
self.redis_increment_operation_queue = []
except Exception as e:
verbose_router_logger.error(
f"Error syncing in-memory cache with Redis: {str(e)}"
)
async def _sync_in_memory_spend_with_redis(self):
"""
Ensures in-memory cache is updated with latest Redis values for all provider spends.
Why do we need this?
- Optimization to hit sub-100ms latency: performance was impacted when Redis was read from/written to on every request
- To support provider budgets in a multi-instance environment, Redis is used to sync spend across all instances
What this does:
1. Push all provider spend increments to Redis
2. Fetch all current provider spend from Redis to update in-memory cache
"""
try:
# No need to sync if Redis cache is not initialized
if self.dual_cache.redis_cache is None:
return
# 1. Push all provider spend increments to Redis
await self._push_in_memory_increments_to_redis()
# 2. Fetch all current provider spend from Redis to update in-memory cache
cache_keys = []
if self.provider_budget_config is not None:
for provider, config in self.provider_budget_config.items():
if config is None:
continue
cache_keys.append(
f"provider_spend:{provider}:{config.budget_duration}"
)
if self.deployment_budget_config is not None:
for model_id, config in self.deployment_budget_config.items():
if config is None:
continue
cache_keys.append(
f"deployment_spend:{model_id}:{config.budget_duration}"
)
if self.tag_budget_config is not None:
for tag, config in self.tag_budget_config.items():
if config is None:
continue
cache_keys.append(f"tag_spend:{tag}:{config.budget_duration}")
# Batch fetch current spend values from Redis
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
key_list=cache_keys
)
# Update in-memory cache with Redis values
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
for key, value in redis_values.items():
if value is not None:
await self.dual_cache.in_memory_cache.async_set_cache(
key=key, value=float(value)
)
verbose_router_logger.debug(
f"Updated in-memory cache for {key}: {value}"
)
except Exception as e:
verbose_router_logger.error(
f"Error syncing in-memory cache with Redis: {str(e)}"
)
def _get_budget_config_for_deployment(
self,
model_id: str,
) -> Optional[GenericBudgetInfo]:
if self.deployment_budget_config is None:
return None
return self.deployment_budget_config.get(model_id, None)
def _get_budget_config_for_provider(
self, provider: str
) -> Optional[GenericBudgetInfo]:
if self.provider_budget_config is None:
return None
return self.provider_budget_config.get(provider, None)
def _get_budget_config_for_tag(self, tag: str) -> Optional[GenericBudgetInfo]:
if self.tag_budget_config is None:
return None
return self.tag_budget_config.get(tag, None)
def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
try:
_litellm_params: LiteLLM_Params = LiteLLM_Params(
**deployment.get("litellm_params", {"model": ""})
)
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=_litellm_params.model,
litellm_params=_litellm_params,
)
except Exception:
verbose_router_logger.error(
f"Error getting LLM provider for deployment: {deployment}"
)
return None
return custom_llm_provider
def _track_provider_remaining_budget_prometheus(
self, provider: str, spend: float, budget_limit: float
):
"""
Optional helper - emit provider remaining budget metric to Prometheus
This is helpful for debugging and monitoring provider budget limits.
"""
prometheus_logger = _get_prometheus_logger_from_callbacks()
if prometheus_logger:
prometheus_logger.track_provider_remaining_budget(
provider=provider,
spend=spend,
budget_limit=budget_limit,
)
async def _get_current_provider_spend(self, provider: str) -> Optional[float]:
"""
GET the current spend for a provider from cache
used for GET /provider/budgets endpoint in spend_management_endpoints.py
Args:
provider (str): The provider to get spend for (e.g., "openai", "anthropic")
Returns:
Optional[float]: The current spend for the provider, or None if not found
"""
budget_config = self._get_budget_config_for_provider(provider)
if budget_config is None:
return None
spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
if self.dual_cache.redis_cache:
# use Redis as source of truth since that has spend across all instances
current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
else:
# use in-memory cache if Redis is not initialized
current_spend = await self.dual_cache.async_get_cache(spend_key)
return float(current_spend) if current_spend is not None else 0.0
async def _get_current_provider_budget_reset_at(
self, provider: str
) -> Optional[str]:
budget_config = self._get_budget_config_for_provider(provider)
if budget_config is None:
return None
spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
if self.dual_cache.redis_cache:
ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
else:
ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)
if ttl_seconds is None:
return None
return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
async def _init_provider_budget_in_cache(
self, provider: str, budget_config: GenericBudgetInfo
):
"""
Initialize provider budget in cache by storing the following keys if they don't exist:
- provider_spend:{provider}:{budget_config.budget_duration} - stores the current spend
- provider_budget_start_time:{provider} - stores the start time of the budget window
"""
spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
start_time_key = f"provider_budget_start_time:{provider}"
ttl_seconds: Optional[int] = None
if budget_config.budget_duration is not None:
ttl_seconds = duration_in_seconds(budget_config.budget_duration)
budget_start = await self.dual_cache.async_get_cache(start_time_key)
if budget_start is None:
budget_start = datetime.now(timezone.utc).timestamp()
await self.dual_cache.async_set_cache(
key=start_time_key, value=budget_start, ttl=ttl_seconds
)
_spend_key = await self.dual_cache.async_get_cache(spend_key)
if _spend_key is None:
await self.dual_cache.async_set_cache(
key=spend_key, value=0.0, ttl=ttl_seconds
)
@staticmethod
def should_init_router_budget_limiter(
provider_budget_config: Optional[dict],
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
):
"""
Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized
Either:
- provider_budget_config is set
- budgets are set for deployments in the model_list
- tag_budget_config is set
"""
if provider_budget_config is not None:
return True
if litellm.tag_budget_config is not None:
return True
if model_list is None:
return False
for _model in model_list:
_litellm_params = _model.get("litellm_params", {})
if (
_litellm_params.get("max_budget")
or _litellm_params.get("budget_duration") is not None
):
return True
return False
def _init_provider_budgets(self):
if self.provider_budget_config is not None:
# cast elements of provider_budget_config to GenericBudgetInfo
for provider, config in self.provider_budget_config.items():
if config is None:
raise ValueError(
f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}"
)
if not isinstance(config, GenericBudgetInfo):
self.provider_budget_config[provider] = GenericBudgetInfo(
budget_limit=config.get("budget_limit"),
time_period=config.get("time_period"),
)
asyncio.create_task(
self._init_provider_budget_in_cache(
provider=provider,
budget_config=self.provider_budget_config[provider],
)
)
verbose_router_logger.debug(
f"Initalized Provider budget config: {self.provider_budget_config}"
)
def _init_deployment_budgets(
self,
model_list: Optional[
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
] = None,
):
if model_list is None:
return
for _model in model_list:
_litellm_params = _model.get("litellm_params", {})
_model_info: Dict = _model.get("model_info") or {}
_model_id = _model_info.get("id")
_max_budget = _litellm_params.get("max_budget")
_budget_duration = _litellm_params.get("budget_duration")
verbose_router_logger.debug(
f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}"
)
if (
_max_budget is not None
and _budget_duration is not None
and _model_id is not None
):
_budget_config = GenericBudgetInfo(
time_period=_budget_duration,
budget_limit=_max_budget,
)
if self.deployment_budget_config is None:
self.deployment_budget_config = {}
self.deployment_budget_config[_model_id] = _budget_config
verbose_router_logger.debug(
f"Initialized Deployment Budget Config: {self.deployment_budget_config}"
)
def _init_tag_budgets(self):
if litellm.tag_budget_config is None:
return
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
if premium_user is not True:
raise ValueError(
f"Tag budgets are an Enterprise only feature, {CommonProxyErrors.not_premium_user}"
)
if self.tag_budget_config is None:
self.tag_budget_config = {}
for _tag, _tag_budget_config in litellm.tag_budget_config.items():
if isinstance(_tag_budget_config, dict):
_tag_budget_config = BudgetConfig(**_tag_budget_config)
_generic_budget_config = GenericBudgetInfo(
time_period=_tag_budget_config.budget_duration,
budget_limit=_tag_budget_config.max_budget,
)
self.tag_budget_config[_tag] = _generic_budget_config
verbose_router_logger.debug(
f"Initialized Tag Budget Config: {self.tag_budget_config}"
)
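
For context, a hedged sketch of how the budget config from the module docstring above plugs into a Router; the `provider_budget_config` wiring is assumed to be passed through to `RouterBudgetLimiting`, and the model names are only examples:

```
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "openai/gpt-3.5-turbo"},
        },
        {
            "model_name": "claude-3-5-sonnet",
            "litellm_params": {"model": "anthropic/claude-3-5-sonnet-20240620"},
        },
    ],
    # keys/values mirror the example in the module docstring above
    provider_budget_config={
        "openai": {"budget_limit": 0.000000000001, "time_period": "1d"},
        "anthropic": {"budget_limit": 100, "time_period": "7d"},
    },
)

# deployments whose provider has exceeded its budget are filtered out by
# RouterBudgetLimiting.async_filter_deployments before a deployment is picked
```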


@@ -0,0 +1,252 @@
#### What this does ####
# identifies least busy deployment
# How is this achieved?
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
# - use litellm.success + failure callbacks to log when a request completed
# - in get_available_deployment, for a given model group name -> pick based on traffic
import random
from typing import Optional
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LeastBusyLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
def __init__(self, router_cache: DualCache, model_list: list):
self.router_cache = router_cache
self.mapping_deployment_to_id: dict = {}
self.model_list = model_list
def log_pre_api_call(self, model, messages, kwargs):
"""
Log when a model is being used.
Caching based on model group.
"""
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# update cache
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
except Exception:
pass
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_value: Optional[int] = request_count_dict.get(id, 0)
if request_count_value is None:
return
request_count_dict[id] = request_count_value - 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception:
pass
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = (
self.router_cache.get_cache(key=request_count_api_key) or {}
)
request_count_value: Optional[int] = request_count_dict.get(id, 0)
if request_count_value is None:
return
request_count_dict[id] = request_count_value - 1
self.router_cache.set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_failure += 1
except Exception:
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = (
await self.router_cache.async_get_cache(key=request_count_api_key)
or {}
)
request_count_value: Optional[int] = request_count_dict.get(id, 0)
if request_count_value is None:
return
request_count_dict[id] = request_count_value - 1
await self.router_cache.async_set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception:
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
request_count_api_key = f"{model_group}_request_count"
# decrement count in cache
request_count_dict = (
await self.router_cache.async_get_cache(key=request_count_api_key)
or {}
)
request_count_value: Optional[int] = request_count_dict.get(id, 0)
if request_count_value is None:
return
request_count_dict[id] = request_count_value - 1
await self.router_cache.async_set_cache(
key=request_count_api_key, value=request_count_dict
)
### TESTING ###
if self.test_flag:
self.logged_failure += 1
except Exception:
pass
def _get_available_deployments(
self,
healthy_deployments: list,
all_deployments: dict,
):
"""
Helper to get deployments using least busy strategy
"""
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = 0
# map deployment to id
# pick least busy deployment
min_traffic = float("inf")
min_deployment = None
for k, v in all_deployments.items():
if v < min_traffic:
min_traffic = v
min_deployment = k
if min_deployment is not None:
## return the full deployment dict for the least busy id
for m in healthy_deployments:
if m["model_info"]["id"] == min_deployment:
return m
min_deployment = random.choice(healthy_deployments)
else:
min_deployment = random.choice(healthy_deployments)
return min_deployment
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
):
"""
Sync helper to get deployments using least busy strategy
"""
request_count_api_key = f"{model_group}_request_count"
all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
return self._get_available_deployments(
healthy_deployments=healthy_deployments,
all_deployments=all_deployments,
)
async def async_get_available_deployments(
self, model_group: str, healthy_deployments: list
):
"""
Async helper to get deployments using least busy strategy
"""
request_count_api_key = f"{model_group}_request_count"
all_deployments = (
await self.router_cache.async_get_cache(key=request_count_api_key) or {}
)
return self._get_available_deployments(
healthy_deployments=healthy_deployments,
all_deployments=all_deployments,
)
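
A minimal sketch of the least-busy pick above, assuming `LeastBusyLoggingHandler` is importable from this module; the deployment ids and in-flight counts are made up:

```
from litellm.caching.caching import DualCache

handler = LeastBusyLoggingHandler(router_cache=DualCache(), model_list=[])

healthy = [
    {"model_info": {"id": "deployment-a"}},
    {"model_info": {"id": "deployment-b"}},
]

# "deployment-b" has fewer in-flight requests, so it is the least busy pick
picked = handler._get_available_deployments(
    healthy_deployments=healthy,
    all_deployments={"deployment-a": 3, "deployment-b": 1},
)
assert picked["model_info"]["id"] == "deployment-b"
```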


@@ -0,0 +1,333 @@
#### What this does ####
# picks the deployment with the lowest cost (input + output cost per token)
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Union
import litellm
from litellm import ModelResponse, token_counter, verbose_logger
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
class LowestCostLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache
self.model_list = model_list
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
# ------------
"""
{
{model_group}_map: {
id: {
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
cost_key = f"{model_group}_map"
response_ms: timedelta = end_time - start_time
total_tokens = 0
if isinstance(response_obj, ModelResponse):
_usage = getattr(response_obj, "usage", None)
if _usage is not None and isinstance(_usage, litellm.Usage):
completion_tokens = _usage.completion_tokens
total_tokens = _usage.total_tokens
float(response_ms.total_seconds() / completion_tokens)
# ------------
# Update usage
# ------------
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
# check local result first
if id not in request_count_dict:
request_count_dict[id] = {}
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM
request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
)
## RPM
request_count_dict[id][precise_minute]["rpm"] = (
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
self.router_cache.set_cache(key=cost_key, value=request_count_dict)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.router_strategy.lowest_cost.py::log_success_event(): Exception occured - {}".format(
str(e)
)
)
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update cost usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
# ------------
"""
{
{model_group}_map: {
id: {
"cost": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
cost_key = f"{model_group}_map"
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
response_ms: timedelta = end_time - start_time
total_tokens = 0
if isinstance(response_obj, ModelResponse):
_usage = getattr(response_obj, "usage", None)
if _usage is not None and isinstance(_usage, litellm.Usage):
completion_tokens = _usage.completion_tokens
total_tokens = _usage.total_tokens
float(response_ms.total_seconds() / completion_tokens)
# ------------
# Update usage
# ------------
request_count_dict = (
await self.router_cache.async_get_cache(key=cost_key) or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM
request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
)
## RPM
request_count_dict[id][precise_minute]["rpm"] = (
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
await self.router_cache.async_set_cache(
key=cost_key, value=request_count_dict
) # reset map within window
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
pass
async def async_get_available_deployments( # noqa: PLR0915
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
"""
Returns a deployment with the lowest cost
"""
cost_key = f"{model_group}_map"
request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}
# -----------------------
# Find lowest used model
# ----------------------
float("inf")
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
if request_count_dict is None: # base case
return
all_deployments = request_count_dict
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = {
precise_minute: {"tpm": 0, "rpm": 0},
}
try:
input_tokens = token_counter(messages=messages, text=input)
except Exception:
input_tokens = 0
# randomly sample from all_deployments, in case all deployments have latency=0.0
_items = all_deployments.items()
### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
potential_deployments = []
_cost_per_deployment = {}
for item, item_map in all_deployments.items():
## get the item from model list
_deployment = None
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
_deployment_tpm = (
_deployment.get("tpm", None)
or _deployment.get("litellm_params", {}).get("tpm", None)
or _deployment.get("model_info", {}).get("tpm", None)
or float("inf")
)
_deployment_rpm = (
_deployment.get("rpm", None)
or _deployment.get("litellm_params", {}).get("rpm", None)
or _deployment.get("model_info", {}).get("rpm", None)
or float("inf")
)
item_litellm_model_name = _deployment.get("litellm_params", {}).get("model")
item_litellm_model_cost_map = litellm.model_cost.get(
item_litellm_model_name, {}
)
# check if user provided input_cost_per_token and output_cost_per_token in litellm_params
item_input_cost = None
item_output_cost = None
if _deployment.get("litellm_params", {}).get("input_cost_per_token", None):
item_input_cost = _deployment.get("litellm_params", {}).get(
"input_cost_per_token"
)
if _deployment.get("litellm_params", {}).get("output_cost_per_token", None):
item_output_cost = _deployment.get("litellm_params", {}).get(
"output_cost_per_token"
)
if item_input_cost is None:
item_input_cost = item_litellm_model_cost_map.get(
"input_cost_per_token", 5.0
)
if item_output_cost is None:
item_output_cost = item_litellm_model_cost_map.get(
"output_cost_per_token", 5.0
)
# if litellm["model"] is not in model_cost map -> use item_cost = $10
item_cost = item_input_cost + item_output_cost
item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
verbose_router_logger.debug(
f"item_cost: {item_cost}, item_tpm: {item_tpm}, item_rpm: {item_rpm}, model_id: {_deployment.get('model_info', {}).get('id')}"
)
# -------------- #
# Debugging Logic
# -------------- #
# We use _cost_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
# this helps a user to debug why the router picked a specific deployment #
_deployment_api_base = _deployment.get("litellm_params", {}).get(
"api_base", ""
)
if _deployment_api_base is not None:
_cost_per_deployment[_deployment_api_base] = item_cost
# -------------- #
# End of Debugging Logic
# -------------- #
if (
item_tpm + input_tokens > _deployment_tpm
or item_rpm + 1 > _deployment_rpm
): # if user passed in tpm / rpm in the model_list
continue
else:
potential_deployments.append((_deployment, item_cost))
if len(potential_deployments) == 0:
return None
potential_deployments = sorted(potential_deployments, key=lambda x: x[1])
selected_deployment = potential_deployments[0][0]
return selected_deployment
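
A hedged sketch of the lowest-cost selection above, assuming `LowestCostLoggingHandler` is importable from this module; the deployment ids and per-token costs are made up:

```
import asyncio

from litellm.caching.caching import DualCache


async def pick_cheapest():
    handler = LowestCostLoggingHandler(router_cache=DualCache(), model_list=[])
    healthy = [
        {
            "model_info": {"id": "dep-mini"},
            "litellm_params": {
                "model": "openai/gpt-4o-mini",
                "input_cost_per_token": 0.00000015,
                "output_cost_per_token": 0.0000006,
            },
        },
        {
            "model_info": {"id": "dep-large"},
            "litellm_params": {
                "model": "openai/gpt-4o",
                "input_cost_per_token": 0.0000025,
                "output_cost_per_token": 0.00001,
            },
        },
    ]
    # no spend recorded yet, so both deployments are under their tpm/rpm limits;
    # the handler returns the deployment with the lowest summed per-token cost
    return await handler.async_get_available_deployments(
        model_group="gpt-group",
        healthy_deployments=healthy,
        messages=[{"role": "user", "content": "hi"}],
    )


deployment = asyncio.run(pick_cheapest())  # expected: the "dep-mini" deployment
```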


@@ -0,0 +1,590 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
import random
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import litellm
from litellm import ModelResponse, token_counter, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.utils import LiteLLMPydanticObjectBase
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class RoutingArgs(LiteLLMPydanticObjectBase):
ttl: float = 1 * 60 * 60 # 1 hour
lowest_latency_buffer: float = 0
max_latency_list_size: int = 10
class LowestLatencyLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache
self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args)
def log_success_event( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
):
try:
"""
Update latency usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
# ------------
"""
{
{model_group}_map: {
id: {
"latency": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
latency_key = f"{model_group}_map"
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
response_ms: timedelta = end_time - start_time
time_to_first_token_response_time: Optional[timedelta] = None
if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
# only log ttft for streaming request
time_to_first_token_response_time = (
kwargs.get("completion_start_time", end_time) - start_time
)
final_value: Union[float, timedelta] = response_ms
time_to_first_token: Optional[float] = None
total_tokens = 0
if isinstance(response_obj, ModelResponse):
_usage = getattr(response_obj, "usage", None)
if _usage is not None:
completion_tokens = _usage.completion_tokens
total_tokens = _usage.total_tokens
final_value = float(
response_ms.total_seconds() / completion_tokens
)
if time_to_first_token_response_time is not None:
time_to_first_token = float(
time_to_first_token_response_time.total_seconds()
/ completion_tokens
)
# ------------
# Update usage
# ------------
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
## Latency
if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(final_value)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [final_value]
## Time to first token
if time_to_first_token is not None:
if (
len(request_count_dict[id].get("time_to_first_token", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault(
"time_to_first_token", []
).append(time_to_first_token)
else:
request_count_dict[id][
"time_to_first_token"
] = request_count_dict[id]["time_to_first_token"][
: self.routing_args.max_latency_list_size - 1
] + [
time_to_first_token
]
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM
request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
)
## RPM
request_count_dict[id][precise_minute]["rpm"] = (
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
self.router_cache.set_cache(
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
) # reset map within window
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
pass
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""
Check if it's a Timeout error; if so, penalize the deployment by recording a 1000s latency
"""
try:
_exception = kwargs.get("exception", None)
if isinstance(_exception, litellm.Timeout):
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
# ------------
"""
{
{model_group}_map: {
id: {
"latency": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
latency_key = f"{model_group}_map"
request_count_dict = (
await self.router_cache.async_get_cache(key=latency_key) or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
## Latency - give 1000s penalty for failing
if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(1000.0)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [1000.0]
await self.router_cache.async_set_cache(
key=latency_key,
value=request_count_dict,
ttl=self.routing_args.ttl,
) # reset map within window
else:
# do nothing if it's not a timeout error
return
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
pass
async def async_log_success_event( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
):
try:
"""
Update latency usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
# ------------
# Setup values
# ------------
"""
{
{model_group}_map: {
id: {
"latency": [..]
"time_to_first_token": [..]
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
}
}
}
"""
latency_key = f"{model_group}_map"
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
response_ms: timedelta = end_time - start_time
time_to_first_token_response_time: Optional[timedelta] = None
if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
# only log ttft for streaming request
time_to_first_token_response_time = (
kwargs.get("completion_start_time", end_time) - start_time
)
final_value: Union[float, timedelta] = response_ms
total_tokens = 0
time_to_first_token: Optional[float] = None
if isinstance(response_obj, ModelResponse):
_usage = getattr(response_obj, "usage", None)
if _usage is not None:
completion_tokens = _usage.completion_tokens
total_tokens = _usage.total_tokens
final_value = float(
response_ms.total_seconds() / completion_tokens
)
if time_to_first_token_response_time is not None:
time_to_first_token = float(
time_to_first_token_response_time.total_seconds()
/ completion_tokens
)
# ------------
# Update usage
# ------------
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
request_count_dict = (
await self.router_cache.async_get_cache(
key=latency_key,
parent_otel_span=parent_otel_span,
local_only=True,
)
or {}
)
if id not in request_count_dict:
request_count_dict[id] = {}
## Latency
if (
len(request_count_dict[id].get("latency", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault("latency", []).append(final_value)
else:
request_count_dict[id]["latency"] = request_count_dict[id][
"latency"
][: self.routing_args.max_latency_list_size - 1] + [final_value]
## Time to first token
if time_to_first_token is not None:
if (
len(request_count_dict[id].get("time_to_first_token", []))
< self.routing_args.max_latency_list_size
):
request_count_dict[id].setdefault(
"time_to_first_token", []
).append(time_to_first_token)
else:
request_count_dict[id][
"time_to_first_token"
] = request_count_dict[id]["time_to_first_token"][
: self.routing_args.max_latency_list_size - 1
] + [
time_to_first_token
]
if precise_minute not in request_count_dict[id]:
request_count_dict[id][precise_minute] = {}
## TPM
request_count_dict[id][precise_minute]["tpm"] = (
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
)
## RPM
request_count_dict[id][precise_minute]["rpm"] = (
request_count_dict[id][precise_minute].get("rpm", 0) + 1
)
await self.router_cache.async_set_cache(
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
) # reset map within window
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
pass
def _get_available_deployments( # noqa: PLR0915
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
request_count_dict: Optional[Dict] = None,
):
"""Common logic for both sync and async get_available_deployments"""
# -----------------------
# Find lowest used model
# ----------------------
_latency_per_deployment = {}
lowest_latency = float("inf")
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
deployment = None
if request_count_dict is None: # base case
return
all_deployments = request_count_dict
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in all_deployments:
all_deployments[d["model_info"]["id"]] = {
"latency": [0],
precise_minute: {"tpm": 0, "rpm": 0},
}
try:
input_tokens = token_counter(messages=messages, text=input)
except Exception:
input_tokens = 0
# randomly sample from all_deployments, in case all deployments have latency=0.0
_items = all_deployments.items()
_all_deployments = random.sample(list(_items), len(_items))
all_deployments = dict(_all_deployments)
### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
potential_deployments = []
for item, item_map in all_deployments.items():
## get the item from model list
_deployment = None
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
_deployment_tpm = (
_deployment.get("tpm", None)
or _deployment.get("litellm_params", {}).get("tpm", None)
or _deployment.get("model_info", {}).get("tpm", None)
or float("inf")
)
_deployment_rpm = (
_deployment.get("rpm", None)
or _deployment.get("litellm_params", {}).get("rpm", None)
or _deployment.get("model_info", {}).get("rpm", None)
or float("inf")
)
item_latency = item_map.get("latency", [])
item_ttft_latency = item_map.get("time_to_first_token", [])
item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
# get average latency or average ttft (depending on streaming/non-streaming)
total: float = 0.0
if (
request_kwargs is not None
and request_kwargs.get("stream", None) is not None
and request_kwargs["stream"] is True
and len(item_ttft_latency) > 0
):
for _call_latency in item_ttft_latency:
if isinstance(_call_latency, float):
total += _call_latency
item_latency = total / len(item_ttft_latency)  # average over the ttft samples that were summed
else:
for _call_latency in item_latency:
if isinstance(_call_latency, float):
total += _call_latency
item_latency = total / len(item_latency)
# -------------- #
# Debugging Logic
# -------------- #
# We use _latency_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
# this helps a user to debug why the router picked a specific deployment #
_deployment_api_base = _deployment.get("litellm_params", {}).get(
"api_base", ""
)
if _deployment_api_base is not None:
_latency_per_deployment[_deployment_api_base] = item_latency
# -------------- #
# End of Debugging Logic
# -------------- #
if (
item_tpm + input_tokens > _deployment_tpm
or item_rpm + 1 > _deployment_rpm
): # if user passed in tpm / rpm in the model_list
continue
else:
potential_deployments.append((_deployment, item_latency))
if len(potential_deployments) == 0:
return None
# Sort potential deployments by latency
sorted_deployments = sorted(potential_deployments, key=lambda x: x[1])
# Find lowest latency deployment
lowest_latency = sorted_deployments[0][1]
# Find deployments within buffer of lowest latency
buffer = self.routing_args.lowest_latency_buffer * lowest_latency
valid_deployments = [
x for x in sorted_deployments if x[1] <= lowest_latency + buffer
]
# Pick a random deployment from valid deployments
random_valid_deployment = random.choice(valid_deployments)
deployment = random_valid_deployment[0]
if request_kwargs is not None and "metadata" in request_kwargs:
request_kwargs["metadata"][
"_latency_per_deployment"
] = _latency_per_deployment
return deployment
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
await self.router_cache.async_get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
request_kwargs: Optional[Dict] = None,
):
"""
Returns a deployment with the lowest latency
"""
# get list of potential deployments
latency_key = f"{model_group}_map"
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
request_kwargs
)
request_count_dict = (
self.router_cache.get_cache(
key=latency_key, parent_otel_span=parent_otel_span
)
or {}
)
return self._get_available_deployments(
model_group,
healthy_deployments,
messages,
input,
request_kwargs,
request_count_dict,
)
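Aside (not part of the committed file): a minimal, self-contained sketch of the latency-buffer pick that `_get_available_deployments` performs, assuming a hypothetical `lowest_latency_buffer` of 0.5 and made-up deployment ids and average latencies.

import random

def pick_within_buffer(avg_latency_by_id: dict, lowest_latency_buffer: float = 0.5) -> str:
    # sort deployment ids by their average latency
    sorted_deployments = sorted(avg_latency_by_id.items(), key=lambda x: x[1])
    lowest_latency = sorted_deployments[0][1]
    buffer = lowest_latency_buffer * lowest_latency
    # keep every deployment within `buffer` of the lowest latency ...
    valid_deployments = [d for d in sorted_deployments if d[1] <= lowest_latency + buffer]
    # ... and pick one at random so near-equal deployments share traffic
    return random.choice(valid_deployments)[0]

print(pick_within_buffer({"dep-a": 0.42, "dep-b": 0.55, "dep-c": 1.9}))  # "dep-a" or "dep-b"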

View File

@@ -0,0 +1,243 @@
#### What this does ####
# identifies lowest tpm deployment
import traceback
from datetime import datetime
from typing import Dict, List, Optional, Union
from litellm import token_counter
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMPydanticObjectBase
from litellm.utils import print_verbose
class RoutingArgs(LiteLLMPydanticObjectBase):
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
class LowestTPMLoggingHandler(CustomLogger):
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache
self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args)
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
current_minute = datetime.now().strftime("%H-%M")
tpm_key = f"{model_group}:tpm:{current_minute}"
rpm_key = f"{model_group}:rpm:{current_minute}"
# ------------
# Update usage
# ------------
## TPM
request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
self.router_cache.set_cache(
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
## RPM
request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
request_count_dict[id] = request_count_dict.get(id, 0) + 1
self.router_cache.set_cache(
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_router_logger.error(
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_router_logger.debug(traceback.format_exc())
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
if kwargs["litellm_params"].get("metadata") is None:
pass
else:
model_group = kwargs["litellm_params"]["metadata"].get(
"model_group", None
)
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = response_obj["usage"]["total_tokens"]
# ------------
# Setup values
# ------------
current_minute = datetime.now().strftime("%H-%M")
tpm_key = f"{model_group}:tpm:{current_minute}"
rpm_key = f"{model_group}:rpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
request_count_dict = (
await self.router_cache.async_get_cache(key=tpm_key) or {}
)
request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens
await self.router_cache.async_set_cache(
key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
## RPM
request_count_dict = (
await self.router_cache.async_get_cache(key=rpm_key) or {}
)
request_count_dict[id] = request_count_dict.get(id, 0) + 1
await self.router_cache.async_set_cache(
key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_router_logger.error(
"litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
verbose_router_logger.debug(traceback.format_exc())
pass
def get_available_deployments( # noqa: PLR0915
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Returns a deployment with the lowest TPM/RPM usage.
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
current_minute = datetime.now().strftime("%H-%M")
tpm_key = f"{model_group}:tpm:{current_minute}"
rpm_key = f"{model_group}:rpm:{current_minute}"
tpm_dict = self.router_cache.get_cache(key=tpm_key)
rpm_dict = self.router_cache.get_cache(key=rpm_key)
verbose_router_logger.debug(
f"tpm_key={tpm_key}, tpm_dict: {tpm_dict}, rpm_dict: {rpm_dict}"
)
try:
input_tokens = token_counter(messages=messages, text=input)
except Exception:
input_tokens = 0
verbose_router_logger.debug(f"input_tokens={input_tokens}")
# -----------------------
# Find lowest used model
# ----------------------
lowest_tpm = float("inf")
if tpm_dict is None: # base case - none of the deployments have been used
# initialize a tpm dict with {model_id: 0}
tpm_dict = {}
for deployment in healthy_deployments:
tpm_dict[deployment["model_info"]["id"]] = 0
else:
for d in healthy_deployments:
## if healthy deployment not yet used
if d["model_info"]["id"] not in tpm_dict:
tpm_dict[d["model_info"]["id"]] = 0
all_deployments = tpm_dict
deployment = None
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = float("inf")
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = float("inf")
if item_tpm + input_tokens > _deployment_tpm:
continue
elif (rpm_dict is not None and item in rpm_dict) and (
rpm_dict[item] + 1 >= _deployment_rpm
):
continue
elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm
deployment = _deployment
print_verbose("returning picked lowest tpm/rpm deployment.")
return deployment
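Aside (not part of the committed file): a rough sketch of the per-minute key scheme this handler writes, using a plain dict in place of DualCache; the model group, deployment id, and token counts below are hypothetical.

from datetime import datetime

cache: dict = {}  # stand-in for the router's DualCache

def record_success(model_group: str, deployment_id: str, total_tokens: int) -> None:
    current_minute = datetime.now().strftime("%H-%M")
    tpm_key = f"{model_group}:tpm:{current_minute}"
    rpm_key = f"{model_group}:rpm:{current_minute}"
    # each key maps deployment id -> usage inside the current minute window
    cache.setdefault(tpm_key, {}).setdefault(deployment_id, 0)
    cache[tpm_key][deployment_id] += total_tokens
    cache.setdefault(rpm_key, {}).setdefault(deployment_id, 0)
    cache[rpm_key][deployment_id] += 1

record_success("gpt-4", "deployment-1", total_tokens=250)
record_success("gpt-4", "deployment-1", total_tokens=100)
print(cache)  # e.g. {'gpt-4:tpm:13-05': {'deployment-1': 350}, 'gpt-4:rpm:13-05': {'deployment-1': 2}}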

View File

@@ -0,0 +1,670 @@
#### What this does ####
# identifies lowest tpm deployment
import random
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm import token_counter
from litellm._logging import verbose_logger, verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.router import RouterErrors
from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
from litellm.utils import get_utc_datetime, print_verbose
from .base_routing_strategy import BaseRoutingStrategy
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class RoutingArgs(LiteLLMPydanticObjectBase):
ttl: int = 1 * 60 # 1min (RPM/TPM expire key)
class LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger):
"""
Updated version of TPM/RPM Logging.
Meant to work across instances.
Caches individual models, not model_groups
Uses batch get (redis.mget)
Increments tpm/rpm limit using redis.incr
"""
test_flag: bool = False
logged_success: int = 0
logged_failure: int = 0
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
def __init__(
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
):
self.router_cache = router_cache
self.model_list = model_list
self.routing_args = RoutingArgs(**routing_args)
BaseRoutingStrategy.__init__(
self,
dual_cache=router_cache,
should_batch_redis_writes=True,
default_sync_interval=0.1,
)
def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
"""
Pre-call check + update model rpm
Returns - deployment
Raises - RateLimitError if deployment over defined RPM limit
"""
try:
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_id = deployment.get("model_info", {}).get("id")
deployment_name = deployment.get("litellm_params", {}).get("model")
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
local_result = self.router_cache.get_cache(
key=rpm_key, local_only=True
) # check local result first
deployment_rpm = None
if deployment_rpm is None:
deployment_rpm = deployment.get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("model_info", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = float("inf")
if local_result is not None and local_result >= deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, local_result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}. id={}, model_group={}. Get the model info by calling 'router.get_model_info(id)".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
local_result,
model_id,
deployment.get("model_name", ""),
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
else:
# if local result below limit, check redis ## prevent unnecessary redis checks
result = self.router_cache.increment_cache(
key=rpm_key, value=1, ttl=self.routing_args.ttl
)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
result,
),
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return deployment
except Exception as e:
if isinstance(e, litellm.RateLimitError):
raise e
return deployment # don't fail calls if eg. redis fails to connect
async def async_pre_call_check(
self, deployment: Dict, parent_otel_span: Optional[Span]
) -> Optional[Dict]:
"""
Pre-call check + update model rpm
- Used inside semaphore
- raise rate limit error if deployment over limit
Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994
Returns - deployment
Raises - RateLimitError if deployment over defined RPM limit
"""
try:
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_id = deployment.get("model_info", {}).get("id")
deployment_name = deployment.get("litellm_params", {}).get("model")
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
local_result = await self.router_cache.async_get_cache(
key=rpm_key, local_only=True
) # check local result first
deployment_rpm = None
if deployment_rpm is None:
deployment_rpm = deployment.get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = deployment.get("model_info", {}).get("rpm")
if deployment_rpm is None:
deployment_rpm = float("inf")
if local_result is not None and local_result >= deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, local_result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
local_result,
),
headers={"retry-after": str(60)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
num_retries=deployment.get("num_retries"),
)
else:
# if local result below limit, check redis ## prevent unnecessary redis checks
result = await self._increment_value_in_current_window(
key=rpm_key, value=1, ttl=self.routing_args.ttl
)
if result is not None and result > deployment_rpm:
raise litellm.RateLimitError(
message="Deployment over defined rpm limit={}. current usage={}".format(
deployment_rpm, result
),
llm_provider="",
model=deployment.get("litellm_params", {}).get("model"),
response=httpx.Response(
status_code=429,
content="{} rpm limit={}. current usage={}".format(
RouterErrors.user_defined_ratelimit_error.value,
deployment_rpm,
result,
),
headers={"retry-after": str(60)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
num_retries=deployment.get("num_retries"),
)
return deployment
except Exception as e:
if isinstance(e, litellm.RateLimitError):
raise e
return deployment # don't fail calls if eg. redis fails to connect
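# Illustrative note (not part of the original commit): the pre-call check above is two-tiered.
# A hypothetical walk-through with made-up numbers:
#   1. the local cache read says rpm=3 for this minute, limit=10 -> under limit, fall through
#   2. _increment_value_in_current_window bumps the shared (Redis-backed) counter; if it
#      returns 11 > 10, a RateLimitError with status 429 and retry-after=60 is raised
# The local read avoids a Redis round trip for deployments that are clearly under their limit.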
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM/RPM usage on success
"""
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_object is None:
raise ValueError("standard_logging_object not passed in.")
model_group = standard_logging_object.get("model_group")
model = standard_logging_object["hidden_params"].get("litellm_model_name")
id = standard_logging_object.get("model_id")
if model_group is None or id is None or model is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = standard_logging_object.get("total_tokens")
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:{model}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
## TPM
self.router_cache.increment_cache(
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format(
str(e)
)
)
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
"""
Update TPM usage on success
"""
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if standard_logging_object is None:
raise ValueError("standard_logging_object not passed in.")
model_group = standard_logging_object.get("model_group")
model = standard_logging_object["hidden_params"]["litellm_model_name"]
id = standard_logging_object.get("model_id")
if model_group is None or id is None:
return
elif isinstance(id, int):
id = str(id)
total_tokens = standard_logging_object.get("total_tokens")
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime(
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:{model}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
# update cache
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
## TPM
await self.router_cache.async_increment_cache(
key=tpm_key,
value=total_tokens,
ttl=self.routing_args.ttl,
parent_otel_span=parent_otel_span,
)
### TESTING ###
if self.test_flag:
self.logged_success += 1
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format(
str(e)
)
)
pass
def _return_potential_deployments(
self,
healthy_deployments: List[Dict],
all_deployments: Dict,
input_tokens: int,
rpm_dict: Dict,
):
lowest_tpm = float("inf")
potential_deployments = [] # if multiple deployments have the same low value
for item, item_tpm in all_deployments.items():
## get the item from model list
_deployment = None
item = item.split(":")[0]
for m in healthy_deployments:
if item == m["model_info"]["id"]:
_deployment = m
if _deployment is None:
continue # skip to next one
elif item_tpm is None:
continue # skip if unhealthy deployment
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get("tpm")
if _deployment_tpm is None:
_deployment_tpm = float("inf")
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
if _deployment_rpm is None:
_deployment_rpm = float("inf")
if item_tpm + input_tokens > _deployment_tpm:
continue
elif (
(rpm_dict is not None and item in rpm_dict)
and rpm_dict[item] is not None
and (rpm_dict[item] + 1 >= _deployment_rpm)
):
continue
elif item_tpm == lowest_tpm:
potential_deployments.append(_deployment)
elif item_tpm < lowest_tpm:
lowest_tpm = item_tpm
potential_deployments = [_deployment]
return potential_deployments
def _common_checks_available_deployment( # noqa: PLR0915
self,
model_group: str,
healthy_deployments: list,
tpm_keys: list,
tpm_values: Optional[list],
rpm_keys: list,
rpm_values: Optional[list],
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
) -> Optional[dict]:
"""
Common checks for get available deployment, across sync + async implementations
"""
if tpm_values is None or rpm_values is None:
return None
tpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(tpm_keys):
tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx]
rpm_dict = {} # {model_id: 1, ..}
for idx, key in enumerate(rpm_keys):
rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx]
try:
input_tokens = token_counter(messages=messages, text=input)
except Exception:
input_tokens = 0
verbose_router_logger.debug(f"input_tokens={input_tokens}")
# -----------------------
# Find lowest used model
# ----------------------
if tpm_dict is None: # base case - none of the deployments have been used
# initialize a tpm dict with {model_id: 0}
tpm_dict = {}
for deployment in healthy_deployments:
tpm_dict[deployment["model_info"]["id"]] = 0
else:
for d in healthy_deployments:
## if healthy deployment not yet used
tpm_key = d["model_info"]["id"]
if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
tpm_dict[tpm_key] = 0
all_deployments = tpm_dict
potential_deployments = self._return_potential_deployments(
healthy_deployments=healthy_deployments,
all_deployments=all_deployments,
input_tokens=input_tokens,
rpm_dict=rpm_dict,
)
print_verbose("returning picked lowest tpm/rpm deployment.")
if len(potential_deployments) > 0:
return random.choice(potential_deployments)
else:
return None
async def async_get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
):
"""
Async implementation of get deployments.
Reduces time to retrieve the tpm/rpm values from cache
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_keys = []
rpm_keys = []
for m in healthy_deployments:
if isinstance(m, dict):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
deployment_name = m.get("litellm_params", {}).get("model")
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
combined_tpm_rpm_keys = tpm_keys + rpm_keys
combined_tpm_rpm_values = await self.router_cache.async_batch_get_cache(
keys=combined_tpm_rpm_keys
) # [1, 2, None, ..]
if combined_tpm_rpm_values is not None:
tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
else:
tpm_values = None
rpm_values = None
deployment = self._common_checks_available_deployment(
model_group=model_group,
healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys,
tpm_values=tpm_values,
rpm_keys=rpm_keys,
rpm_values=rpm_values,
messages=messages,
input=input,
)
try:
assert deployment is not None
return deployment
except Exception:
### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
deployment_dict = {}
for index, _deployment in enumerate(healthy_deployments):
if isinstance(_deployment, dict):
id = _deployment.get("model_info", {}).get("id")
### GET DEPLOYMENT TPM LIMIT ###
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index] if tpm_values else 0
### GET DEPLOYMENT RPM LIMIT ###
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index] if rpm_values else 0
deployment_dict[id] = {
"current_tpm": current_tpm,
"tpm_limit": _deployment_tpm,
"current_rpm": current_rpm,
"rpm_limit": _deployment_rpm,
}
raise litellm.RateLimitError(
message=f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}",
llm_provider="",
model=model_group,
response=httpx.Response(
status_code=429,
content="",
headers={"retry-after": str(60)}, # type: ignore
request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
def get_available_deployments(
self,
model_group: str,
healthy_deployments: list,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
parent_otel_span: Optional[Span] = None,
):
"""
Returns a deployment with the lowest TPM/RPM usage.
"""
# get list of potential deployments
verbose_router_logger.debug(
f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_keys = []
rpm_keys = []
for m in healthy_deployments:
if isinstance(m, dict):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
deployment_name = m.get("litellm_params", {}).get("model")
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
tpm_values = self.router_cache.batch_get_cache(
keys=tpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..]
rpm_values = self.router_cache.batch_get_cache(
keys=rpm_keys, parent_otel_span=parent_otel_span
) # [1, 2, None, ..]
deployment = self._common_checks_available_deployment(
model_group=model_group,
healthy_deployments=healthy_deployments,
tpm_keys=tpm_keys,
tpm_values=tpm_values,
rpm_keys=rpm_keys,
rpm_values=rpm_values,
messages=messages,
input=input,
)
try:
assert deployment is not None
return deployment
except Exception:
### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
deployment_dict = {}
for index, _deployment in enumerate(healthy_deployments):
if isinstance(_deployment, dict):
id = _deployment.get("model_info", {}).get("id")
### GET DEPLOYMENT TPM LIMIT ###
_deployment_tpm = None
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("tpm", None)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("litellm_params", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = _deployment.get("model_info", {}).get(
"tpm", None
)
if _deployment_tpm is None:
_deployment_tpm = float("inf")
### GET CURRENT TPM ###
current_tpm = tpm_values[index] if tpm_values else 0
### GET DEPLOYMENT RPM LIMIT ###
_deployment_rpm = None
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("rpm", None)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("litellm_params", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = _deployment.get("model_info", {}).get(
"rpm", None
)
if _deployment_rpm is None:
_deployment_rpm = float("inf")
### GET CURRENT RPM ###
current_rpm = rpm_values[index] if rpm_values else 0
deployment_dict[id] = {
"current_tpm": current_tpm,
"tpm_limit": _deployment_tpm,
"current_rpm": current_rpm,
"rpm_limit": _deployment_rpm,
}
raise ValueError(
f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
)
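Aside (not part of the committed file): a small sketch of how the async path splits one batched cache read back into TPM and RPM values; the keys and returned values below are hypothetical.

tpm_keys = ["id-1:gpt-4:tpm:13-05", "id-2:gpt-4:tpm:13-05"]
rpm_keys = ["id-1:gpt-4:rpm:13-05", "id-2:gpt-4:rpm:13-05"]

# one batched read (redis MGET style) instead of two round trips
combined_tpm_rpm_values = [120, None, 3, None]  # what a batch get might return

tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]  # [120, None]
rpm_values = combined_tpm_rpm_values[len(tpm_keys):]   # [3, None]
print(tpm_values, rpm_values)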

View File

@@ -0,0 +1,96 @@
"""
Returns a random deployment from the list of healthy deployments.
If weights are provided, it will return a deployment based on the weights.
"""
import random
from typing import TYPE_CHECKING, Any, Dict, List, Union
from litellm._logging import verbose_router_logger
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def simple_shuffle(
llm_router_instance: LitellmRouter,
healthy_deployments: Union[List[Any], Dict[Any, Any]],
model: str,
) -> Dict:
"""
Returns a random deployment from the list of healthy deployments.
If weights are provided, it will return a deployment based on the weights.
If users pass `rpm` or `tpm`, we do a random weighted pick - based on `rpm`/`tpm`.
Args:
llm_router_instance: LitellmRouter instance
healthy_deployments: List of healthy deployments
model: Model name
Returns:
Dict: A single healthy deployment
"""
############## Check if 'weight' param set for a weighted pick #################
weight = healthy_deployments[0].get("litellm_params").get("weight", None)
if weight is not None:
# use weighted random pick if weights provided
weights = [m["litellm_params"].get("weight", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\nweight {weights}")
total_weight = sum(weights)
weights = [weight / total_weight for weight in weights]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(weights)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## Check if we can do a RPM/TPM based weighted pick #################
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
if rpm is not None:
# use weight-random pick if rpms provided
rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\nrpms {rpms}")
total_rpm = sum(rpms)
weights = [rpm / total_rpm for rpm in rpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(rpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## Check if we can do a RPM/TPM based weighted pick #################
tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
if tpm is not None:
# use weighted random pick if tpms provided
tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
verbose_router_logger.debug(f"\ntpms {tpms}")
total_tpm = sum(tpms)
weights = [tpm / total_tpm for tpm in tpms]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(tpms)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## No RPM/TPM passed, we do a random pick #################
item = random.choice(healthy_deployments)
return item or item[0]
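Aside (not part of the committed file): the weighted pick above, reduced to its core; the deployment list and rpm values are hypothetical.

import random

healthy_deployments = [
    {"model_info": {"id": "a"}, "litellm_params": {"rpm": 100}},
    {"model_info": {"id": "b"}, "litellm_params": {"rpm": 300}},
]
rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
weights = [r / sum(rpms) for r in rpms]  # normalize so the weights sum to 1
selected_index = random.choices(range(len(rpms)), weights=weights)[0]
print(healthy_deployments[selected_index]["model_info"]["id"])  # "b" roughly 75% of the time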

View File

@@ -0,0 +1,146 @@
"""
Use this to route requests between Teams
- If any tag in the request matches a tag on a deployment, return that deployment
- If deployments are set with default tags, return all default deployments
- If no default deployments are set, return all deployments
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from litellm._logging import verbose_logger
from litellm.types.router import RouterErrors
if TYPE_CHECKING:
from litellm.router import Router as _Router
LitellmRouter = _Router
else:
LitellmRouter = Any
def is_valid_deployment_tag(
deployment_tags: List[str], request_tags: List[str]
) -> bool:
"""
Check whether a deployment's tags match the request tags (or include "default")
"""
if any(tag in deployment_tags for tag in request_tags):
verbose_logger.debug(
"adding deployment with tags: %s, request tags: %s",
deployment_tags,
request_tags,
)
return True
elif "default" in deployment_tags:
verbose_logger.debug(
"adding default deployment with tags: %s, request tags: %s",
deployment_tags,
request_tags,
)
return True
return False
async def get_deployments_for_tag(
llm_router_instance: LitellmRouter,
model: str, # used to raise the correct error
healthy_deployments: Union[List[Any], Dict[Any, Any]],
request_kwargs: Optional[Dict[Any, Any]] = None,
):
"""
Returns a list of deployments that match the requested model and tags in the request.
Executes tag based filtering based on the tags in request metadata and the tags on the deployments
"""
if llm_router_instance.enable_tag_filtering is not True:
return healthy_deployments
if request_kwargs is None:
verbose_logger.debug(
"get_deployments_for_tag: request_kwargs is None returning healthy_deployments: %s",
healthy_deployments,
)
return healthy_deployments
if healthy_deployments is None:
verbose_logger.debug(
"get_deployments_for_tag: healthy_deployments is None returning healthy_deployments"
)
return healthy_deployments
verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
if "metadata" in request_kwargs:
metadata = request_kwargs["metadata"]
request_tags = metadata.get("tags")
new_healthy_deployments = []
if request_tags:
verbose_logger.debug(
"get_deployments_for_tag routing: router_keys: %s", request_tags
)
# example: request_tags can be ["free", "custom"]
# keep all deployments whose tags overlap with these request tags
for deployment in healthy_deployments:
deployment_litellm_params = deployment.get("litellm_params")
deployment_tags = deployment_litellm_params.get("tags")
verbose_logger.debug(
"deployment: %s, deployment_router_keys: %s",
deployment,
deployment_tags,
)
if deployment_tags is None:
continue
if is_valid_deployment_tag(deployment_tags, request_tags):
new_healthy_deployments.append(deployment)
if len(new_healthy_deployments) == 0:
raise ValueError(
f"{RouterErrors.no_deployments_with_tag_routing.value}. Passed model={model} and tags={request_tags}"
)
return new_healthy_deployments
# for Untagged requests use default deployments if set
_default_deployments_with_tags = []
for deployment in healthy_deployments:
if "default" in deployment.get("litellm_params", {}).get("tags", []):
_default_deployments_with_tags.append(deployment)
if len(_default_deployments_with_tags) > 0:
return _default_deployments_with_tags
# if no default deployment is found, return healthy_deployments
verbose_logger.debug(
"no tier found in metadata, returning healthy_deployments: %s",
healthy_deployments,
)
return healthy_deployments
def _get_tags_from_request_kwargs(
request_kwargs: Optional[Dict[Any, Any]] = None
) -> List[str]:
"""
Helper to get tags from request kwargs
Args:
request_kwargs: The request kwargs to get tags from
Returns:
List[str]: The tags from the request kwargs
"""
if request_kwargs is None:
return []
if "metadata" in request_kwargs:
metadata = request_kwargs["metadata"]
return metadata.get("tags", [])
elif "litellm_params" in request_kwargs:
litellm_params = request_kwargs["litellm_params"]
_metadata = litellm_params.get("metadata", {})
return _metadata.get("tags", [])
return []
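Aside (not part of the committed file): the tag filter in `get_deployments_for_tag`, reduced to a standalone snippet; the deployments and request tags below are hypothetical.

deployments = [
    {"litellm_params": {"tags": ["free"]}},
    {"litellm_params": {"tags": ["paid"]}},
    {"litellm_params": {"tags": ["default"]}},
]
request_tags = ["paid"]

# a deployment qualifies if any of its tags is requested, or if it is tagged "default"
matched = [
    d
    for d in deployments
    if any(t in d["litellm_params"]["tags"] for t in request_tags)
    or "default" in d["litellm_params"]["tags"]
]
print([d["litellm_params"]["tags"] for d in matched])  # [['paid'], ['default']]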