structure saas with tools
Binary files in this commit are not shown.
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Base class shared across routing strategies to abstract common functions like batched Redis increments
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from abc import ABC
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
|
||||
from litellm.constants import DEFAULT_REDIS_SYNC_INTERVAL
|
||||
|
||||
|
||||
class BaseRoutingStrategy(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
dual_cache: DualCache,
|
||||
should_batch_redis_writes: bool,
|
||||
default_sync_interval: Optional[Union[int, float]],
|
||||
):
|
||||
self.dual_cache = dual_cache
|
||||
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
|
||||
if should_batch_redis_writes:
|
||||
try:
|
||||
# Try to get existing event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# If loop exists and is running, create task in existing loop
|
||||
loop.create_task(
|
||||
self.periodic_sync_in_memory_spend_with_redis(
|
||||
default_sync_interval=default_sync_interval
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._create_sync_thread(default_sync_interval)
|
||||
except RuntimeError: # No event loop in current thread
|
||||
self._create_sync_thread(default_sync_interval)
|
||||
|
||||
self.in_memory_keys_to_update: set[
|
||||
str
|
||||
] = set() # Set with max size of 1000 keys
|
||||
|
||||
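The loop-or-thread branching in `__init__` above is a general asyncio scheduling pattern rather than anything LiteLLM-specific. A minimal, stdlib-only sketch of the same idea (the function and task names here are illustrative):

```
import asyncio
import threading

async def periodic_task(interval: float):
    while True:
        print("sync tick")
        await asyncio.sleep(interval)

def schedule_periodic(interval: float = 1.0):
    """Schedule periodic_task on the running loop if one exists,
    otherwise run it in a daemon thread with its own event loop."""
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            loop.create_task(periodic_task(interval))
            return
    except RuntimeError:
        pass  # no event loop in the current thread
    threading.Thread(
        target=asyncio.run, args=(periodic_task(interval),), daemon=True
    ).start()
```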
async def _increment_value_in_current_window(
|
||||
self, key: str, value: Union[int, float], ttl: int
|
||||
):
|
||||
"""
|
||||
Increment spend within existing budget window
|
||||
|
||||
Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider)
|
||||
|
||||
- Increments the spend in memory cache (so spend instantly updated in memory)
|
||||
- Queues the increment operation to the Redis pipeline (batched to optimize performance; Redis is used to sync spend across multi-instance LiteLLM deployments)
|
||||
"""
|
||||
result = await self.dual_cache.in_memory_cache.async_increment(
|
||||
key=key,
|
||||
value=value,
|
||||
ttl=ttl,
|
||||
)
|
||||
increment_op = RedisPipelineIncrementOperation(
|
||||
key=key,
|
||||
increment_value=value,
|
||||
ttl=ttl,
|
||||
)
|
||||
self.redis_increment_operation_queue.append(increment_op)
|
||||
self.add_to_in_memory_keys_to_update(key=key)
|
||||
return result
|
||||
|
||||
async def periodic_sync_in_memory_spend_with_redis(
|
||||
self, default_sync_interval: Optional[Union[int, float]]
|
||||
):
|
||||
"""
|
||||
Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds
|
||||
|
||||
Required for multi-instance environment usage of provider budgets
|
||||
"""
|
||||
default_sync_interval = default_sync_interval or DEFAULT_REDIS_SYNC_INTERVAL
|
||||
while True:
|
||||
try:
|
||||
await self._sync_in_memory_spend_with_redis()
|
||||
await asyncio.sleep(
|
||||
default_sync_interval
|
||||
) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
|
||||
await asyncio.sleep(
|
||||
default_sync_interval
|
||||
) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying
|
||||
|
||||
async def _push_in_memory_increments_to_redis(self):
|
||||
"""
|
||||
How this works:
|
||||
- async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
|
||||
- This function pushes all increments to Redis in a batched pipeline to optimize performance
|
||||
|
||||
Only runs if Redis is initialized
|
||||
"""
|
||||
try:
|
||||
if not self.dual_cache.redis_cache:
|
||||
return # Redis is not initialized
|
||||
|
||||
verbose_router_logger.debug(
|
||||
"Pushing Redis Increment Pipeline for queue: %s",
|
||||
self.redis_increment_operation_queue,
|
||||
)
|
||||
if len(self.redis_increment_operation_queue) > 0:
|
||||
asyncio.create_task(
|
||||
self.dual_cache.redis_cache.async_increment_pipeline(
|
||||
increment_list=self.redis_increment_operation_queue,
|
||||
)
|
||||
)
|
||||
|
||||
self.redis_increment_operation_queue = []
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(
|
||||
f"Error syncing in-memory cache with Redis: {str(e)}"
|
||||
)
|
||||
self.redis_increment_operation_queue = []
|
||||
|
||||
def add_to_in_memory_keys_to_update(self, key: str):
|
||||
self.in_memory_keys_to_update.add(key)
|
||||
|
||||
def get_in_memory_keys_to_update(self) -> Set[str]:
|
||||
return self.in_memory_keys_to_update
|
||||
|
||||
def reset_in_memory_keys_to_update(self):
|
||||
self.in_memory_keys_to_update = set()
|
||||
|
||||
async def _sync_in_memory_spend_with_redis(self):
|
||||
"""
|
||||
Ensures in-memory cache is updated with latest Redis values for all provider spends.
|
||||
|
||||
Why do we need this?
- Optimization to hit sub-100ms latency. Performance was impacted when Redis was used for a read/write on every request
- To use provider budgets in a multi-instance environment, Redis is used to sync spend across all instances
|
||||
|
||||
What this does:
|
||||
1. Push all provider spend increments to Redis
|
||||
2. Fetch all current provider spend from Redis to update in-memory cache
|
||||
"""
|
||||
|
||||
try:
|
||||
# No need to sync if Redis cache is not initialized
|
||||
if self.dual_cache.redis_cache is None:
|
||||
return
|
||||
|
||||
# 1. Push all provider spend increments to Redis
|
||||
await self._push_in_memory_increments_to_redis()
|
||||
|
||||
# 2. Fetch all current provider spend from Redis to update in-memory cache
|
||||
cache_keys = self.get_in_memory_keys_to_update()
|
||||
|
||||
cache_keys_list = list(cache_keys)
|
||||
|
||||
# Batch fetch current spend values from Redis
|
||||
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
|
||||
key_list=cache_keys_list
|
||||
)
|
||||
|
||||
# Update in-memory cache with Redis values
|
||||
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
|
||||
for key, value in redis_values.items():
|
||||
if value is not None:
|
||||
await self.dual_cache.in_memory_cache.async_set_cache(
|
||||
key=key, value=float(value)
|
||||
)
|
||||
verbose_router_logger.debug(
|
||||
f"Updated in-memory cache for {key}: {value}"
|
||||
)
|
||||
|
||||
self.reset_in_memory_keys_to_update()
|
||||
except Exception as e:
|
||||
verbose_router_logger.exception(
|
||||
f"Error syncing in-memory cache with Redis: {str(e)}"
|
||||
)
|
||||
|
||||
def _create_sync_thread(self, default_sync_interval):
|
||||
"""Helper method to create a new thread for periodic sync"""
|
||||
thread = threading.Thread(
|
||||
target=asyncio.run,
|
||||
args=(
|
||||
self.periodic_sync_in_memory_spend_with_redis(
|
||||
default_sync_interval=default_sync_interval
|
||||
),
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
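A rough sketch of how a routing strategy might build on this base class. It assumes `DualCache()` can be constructed with its defaults and that the class lives at `litellm.router_strategy.base_routing_strategy` (the diff does not show file paths); the subclass and key names are illustrative only:

```
import asyncio
from litellm.caching.caching import DualCache
from litellm.router_strategy.base_routing_strategy import BaseRoutingStrategy  # path assumed

class SpendTrackingStrategy(BaseRoutingStrategy):
    """Hypothetical strategy that records per-deployment spend."""

    async def record_spend(self, deployment_id: str, cost: float):
        # updates the in-memory cache immediately and queues a Redis increment
        # that the periodic sync task flushes in a batched pipeline
        return await self._increment_value_in_current_window(
            key=f"deployment_spend:{deployment_id}:1d", value=cost, ttl=86400
        )

async def main():
    strategy = SpendTrackingStrategy(
        dual_cache=DualCache(),
        should_batch_redis_writes=True,
        default_sync_interval=1,
    )
    await strategy.record_spend("my-deployment-id", 0.002)

asyncio.run(main())
```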
@@ -0,0 +1,821 @@
|
||||
"""
|
||||
Provider budget limiting
|
||||
|
||||
Use this if you want to set $ budget limits for each provider.
|
||||
|
||||
Note: This is a filter, like tag-routing. Meaning it will accept healthy deployments and then filter out deployments that have exceeded their budget limit.
|
||||
|
||||
This means you can use this with weighted-pick, lowest-latency, simple-shuffle, routing etc
|
||||
|
||||
Example:
|
||||
```
|
||||
openai:
|
||||
budget_limit: 0.000000000001
|
||||
time_period: 1d
|
||||
anthropic:
|
||||
budget_limit: 100
|
||||
time_period: 7d
|
||||
```
|
||||
"""
|
||||
|
||||
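For reference, a sketch of what the YAML example in the docstring above looks like once loaded into Python, and how it would typically reach this class via the router. Whether your `litellm.Router` version accepts a `provider_budget_config` kwarg is an assumption to verify; the values are illustrative only:

```
# illustrative only - mirrors the YAML example in the module docstring
provider_budget_config = {
    "openai": {"budget_limit": 0.000000000001, "time_period": "1d"},
    "anthropic": {"budget_limit": 100, "time_period": "7d"},
}

# hypothetical wiring; verify the kwarg name against your litellm version
# from litellm import Router
# router = Router(model_list=model_list, provider_budget_config=provider_budget_config)
```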
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
|
||||
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
|
||||
from litellm.router_utils.cooldown_callbacks import (
|
||||
_get_prometheus_logger_from_callbacks,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params, RouterErrors
|
||||
from litellm.types.utils import BudgetConfig
|
||||
from litellm.types.utils import BudgetConfig as GenericBudgetInfo
|
||||
from litellm.types.utils import GenericBudgetConfigType, StandardLoggingPayload
|
||||
|
||||
DEFAULT_REDIS_SYNC_INTERVAL = 1
|
||||
|
||||
|
||||
class RouterBudgetLimiting(CustomLogger):
|
||||
def __init__(
|
||||
self,
|
||||
dual_cache: DualCache,
|
||||
provider_budget_config: Optional[dict],
|
||||
model_list: Optional[
|
||||
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
|
||||
] = None,
|
||||
):
|
||||
self.dual_cache = dual_cache
|
||||
self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
|
||||
asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
|
||||
self.provider_budget_config: Optional[
|
||||
GenericBudgetConfigType
|
||||
] = provider_budget_config
|
||||
self.deployment_budget_config: Optional[GenericBudgetConfigType] = None
|
||||
self.tag_budget_config: Optional[GenericBudgetConfigType] = None
|
||||
self._init_provider_budgets()
|
||||
self._init_deployment_budgets(model_list=model_list)
|
||||
self._init_tag_budgets()
|
||||
|
||||
# Add self to litellm callbacks if it's a list
|
||||
if isinstance(litellm.callbacks, list):
|
||||
litellm.logging_callback_manager.add_litellm_callback(self) # type: ignore
|
||||
|
||||
async def async_filter_deployments(
|
||||
self,
|
||||
model: str,
|
||||
healthy_deployments: List,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
request_kwargs: Optional[dict] = None,
|
||||
parent_otel_span: Optional[Span] = None, # type: ignore
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Filter out deployments that have exceeded their provider budget limit.
|
||||
|
||||
|
||||
Example:
|
||||
if deployment = openai/gpt-3.5-turbo
|
||||
and openai spend > openai budget limit
|
||||
then skip this deployment
|
||||
"""
|
||||
|
||||
# If a single deployment is passed, convert it to a list
|
||||
if isinstance(healthy_deployments, dict):
|
||||
healthy_deployments = [healthy_deployments]
|
||||
|
||||
# Don't do any filtering if there are no healthy deployments
|
||||
if len(healthy_deployments) == 0:
|
||||
return healthy_deployments
|
||||
|
||||
potential_deployments: List[Dict] = []
|
||||
|
||||
(
|
||||
cache_keys,
|
||||
provider_configs,
|
||||
deployment_configs,
|
||||
) = await self._async_get_cache_keys_for_router_budget_limiting(
|
||||
healthy_deployments=healthy_deployments,
|
||||
request_kwargs=request_kwargs,
|
||||
)
|
||||
|
||||
# Single cache read for all spend values
|
||||
if len(cache_keys) > 0:
|
||||
_current_spends = await self.dual_cache.async_batch_get_cache(
|
||||
keys=cache_keys,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
current_spends: List = _current_spends or [0.0] * len(cache_keys)
|
||||
|
||||
# Map spends to their respective keys
|
||||
spend_map: Dict[str, float] = {}
|
||||
for idx, key in enumerate(cache_keys):
|
||||
spend_map[key] = float(current_spends[idx] or 0.0)
|
||||
|
||||
(
|
||||
potential_deployments,
|
||||
deployment_above_budget_info,
|
||||
) = self._filter_out_deployments_above_budget(
|
||||
healthy_deployments=healthy_deployments,
|
||||
provider_configs=provider_configs,
|
||||
deployment_configs=deployment_configs,
|
||||
spend_map=spend_map,
|
||||
potential_deployments=potential_deployments,
|
||||
request_tags=_get_tags_from_request_kwargs(
|
||||
request_kwargs=request_kwargs
|
||||
),
|
||||
)
|
||||
|
||||
if len(potential_deployments) == 0:
|
||||
raise ValueError(
|
||||
f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
|
||||
)
|
||||
|
||||
return potential_deployments
|
||||
else:
|
||||
return healthy_deployments
|
||||
|
||||
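The spend lookups above hinge on a simple cache-key scheme. A sketch of what `spend_map` might contain for one request and how the filter decision falls out of it (all values are made up):

```
# illustrative contents of spend_map after the single batch cache read
spend_map = {
    "provider_spend:openai:1d": 0.42,        # provider-level spend in its 1d window
    "deployment_spend:<model-id>:7d": 1.10,  # per-deployment spend
    "tag_spend:teamA:30d": 3.75,             # per-request-tag spend
}

# a deployment is dropped when its provider / deployment / tag spend
# meets or exceeds the configured max_budget, e.g. with max_budget=0.40:
over_budget = spend_map["provider_spend:openai:1d"] >= 0.40  # True -> filtered out
```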
def _filter_out_deployments_above_budget(
|
||||
self,
|
||||
potential_deployments: List[Dict[str, Any]],
|
||||
healthy_deployments: List[Dict[str, Any]],
|
||||
provider_configs: Dict[str, GenericBudgetInfo],
|
||||
deployment_configs: Dict[str, GenericBudgetInfo],
|
||||
spend_map: Dict[str, float],
|
||||
request_tags: List[str],
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
"""
|
||||
Filter out deployments that have exceeded their budget limit.
|
||||
The following budget checks are run here:
|
||||
- Provider budget
|
||||
- Deployment budget
|
||||
- Request tags budget
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], str]:
|
||||
- A tuple containing the filtered deployments
|
||||
- A string containing debug information about deployments that exceeded their budget limit.
|
||||
"""
|
||||
# Filter deployments based on both provider and deployment budgets
|
||||
deployment_above_budget_info: str = ""
|
||||
for deployment in healthy_deployments:
|
||||
is_within_budget = True
|
||||
|
||||
# Check provider budget
|
||||
if self.provider_budget_config:
|
||||
provider = self._get_llm_provider_for_deployment(deployment)
|
||||
if provider in provider_configs:
|
||||
config = provider_configs[provider]
|
||||
if config.max_budget is None:
|
||||
continue
|
||||
current_spend = spend_map.get(
|
||||
f"provider_spend:{provider}:{config.budget_duration}", 0.0
|
||||
)
|
||||
self._track_provider_remaining_budget_prometheus(
|
||||
provider=provider,
|
||||
spend=current_spend,
|
||||
budget_limit=config.max_budget,
|
||||
)
|
||||
|
||||
if config.max_budget and current_spend >= config.max_budget:
|
||||
debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.max_budget}"
|
||||
deployment_above_budget_info += f"{debug_msg}\n"
|
||||
is_within_budget = False
|
||||
continue
|
||||
|
||||
# Check deployment budget
|
||||
if self.deployment_budget_config and is_within_budget:
|
||||
_model_name = deployment.get("model_name")
|
||||
_litellm_params = deployment.get("litellm_params") or {}
|
||||
_litellm_model_name = _litellm_params.get("model")
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
if model_id in deployment_configs:
|
||||
config = deployment_configs[model_id]
|
||||
current_spend = spend_map.get(
|
||||
f"deployment_spend:{model_id}:{config.budget_duration}", 0.0
|
||||
)
|
||||
if config.max_budget and current_spend >= config.max_budget:
|
||||
debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_duration}"
|
||||
verbose_router_logger.debug(debug_msg)
|
||||
deployment_above_budget_info += f"{debug_msg}\n"
|
||||
is_within_budget = False
|
||||
continue
|
||||
# Check tag budget
|
||||
if self.tag_budget_config and is_within_budget:
|
||||
for _tag in request_tags:
|
||||
_tag_budget_config = self._get_budget_config_for_tag(_tag)
|
||||
if _tag_budget_config:
|
||||
_tag_spend = spend_map.get(
|
||||
f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}",
|
||||
0.0,
|
||||
)
|
||||
if (
|
||||
_tag_budget_config.max_budget
|
||||
and _tag_spend >= _tag_budget_config.max_budget
|
||||
):
|
||||
debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.max_budget}"
|
||||
verbose_router_logger.debug(debug_msg)
|
||||
deployment_above_budget_info += f"{debug_msg}\n"
|
||||
is_within_budget = False
|
||||
continue
|
||||
if is_within_budget:
|
||||
potential_deployments.append(deployment)
|
||||
|
||||
return potential_deployments, deployment_above_budget_info
|
||||
|
||||
async def _async_get_cache_keys_for_router_budget_limiting(
|
||||
self,
|
||||
healthy_deployments: List[Dict[str, Any]],
|
||||
request_kwargs: Optional[Dict] = None,
|
||||
) -> Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]:
|
||||
"""
|
||||
Returns the list of cache keys to fetch from the router cache for budget limiting, along with the provider and deployment budget configs
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo]]:
|
||||
- List of cache keys to fetch from router cache for budget limiting
|
||||
- Dict of provider budget configs `provider_configs`
|
||||
- Dict of deployment budget configs `deployment_configs`
|
||||
"""
|
||||
cache_keys: List[str] = []
|
||||
provider_configs: Dict[str, GenericBudgetInfo] = {}
|
||||
deployment_configs: Dict[str, GenericBudgetInfo] = {}
|
||||
|
||||
for deployment in healthy_deployments:
|
||||
# Check provider budgets
|
||||
if self.provider_budget_config:
|
||||
provider = self._get_llm_provider_for_deployment(deployment)
|
||||
if provider is not None:
|
||||
budget_config = self._get_budget_config_for_provider(provider)
|
||||
if (
|
||||
budget_config is not None
|
||||
and budget_config.budget_duration is not None
|
||||
):
|
||||
provider_configs[provider] = budget_config
|
||||
cache_keys.append(
|
||||
f"provider_spend:{provider}:{budget_config.budget_duration}"
|
||||
)
|
||||
|
||||
# Check deployment budgets
|
||||
if self.deployment_budget_config:
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
if model_id is not None:
|
||||
budget_config = self._get_budget_config_for_deployment(model_id)
|
||||
if budget_config is not None:
|
||||
deployment_configs[model_id] = budget_config
|
||||
cache_keys.append(
|
||||
f"deployment_spend:{model_id}:{budget_config.budget_duration}"
|
||||
)
|
||||
# Check tag budgets
|
||||
if self.tag_budget_config:
|
||||
request_tags = _get_tags_from_request_kwargs(
|
||||
request_kwargs=request_kwargs
|
||||
)
|
||||
for _tag in request_tags:
|
||||
_tag_budget_config = self._get_budget_config_for_tag(_tag)
|
||||
if _tag_budget_config:
|
||||
cache_keys.append(
|
||||
f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
|
||||
)
|
||||
return cache_keys, provider_configs, deployment_configs
|
||||
|
||||
async def _get_or_set_budget_start_time(
|
||||
self, start_time_key: str, current_time: float, ttl_seconds: int
|
||||
) -> float:
|
||||
"""
|
||||
Checks if the key = `provider_budget_start_time:{provider}` exists in cache.
|
||||
|
||||
If it does, return the value.
|
||||
If it does not, set the key to `current_time` and return the value.
|
||||
"""
|
||||
budget_start = await self.dual_cache.async_get_cache(start_time_key)
|
||||
if budget_start is None:
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||
)
|
||||
return current_time
|
||||
return float(budget_start)
|
||||
|
||||
async def _handle_new_budget_window(
|
||||
self,
|
||||
spend_key: str,
|
||||
start_time_key: str,
|
||||
current_time: float,
|
||||
response_cost: float,
|
||||
ttl_seconds: int,
|
||||
) -> float:
|
||||
"""
|
||||
Handle start of new budget window by resetting spend and start time
|
||||
|
||||
Enters this when:
|
||||
- The budget does not exist in cache, so we need to set it
|
||||
- The budget window has expired, so we need to reset everything
|
||||
|
||||
Does 2 things:
|
||||
- stores key: `provider_spend:{provider}:1d`, value: response_cost
|
||||
- stores key: `provider_budget_start_time:{provider}`, value: current_time.
|
||||
This stores the start time of the new budget window
|
||||
"""
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=spend_key, value=response_cost, ttl=ttl_seconds
|
||||
)
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||
)
|
||||
return current_time
|
||||
|
||||
async def _increment_spend_in_current_window(
|
||||
self, spend_key: str, response_cost: float, ttl: int
|
||||
):
|
||||
"""
|
||||
Increment spend within existing budget window
|
||||
|
||||
Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider)
|
||||
|
||||
- Increments the spend in memory cache (so spend instantly updated in memory)
|
||||
- Queues the increment operation to the Redis pipeline (batched to optimize performance; Redis is used to sync spend across multi-instance LiteLLM deployments)
|
||||
"""
|
||||
await self.dual_cache.in_memory_cache.async_increment(
|
||||
key=spend_key,
|
||||
value=response_cost,
|
||||
ttl=ttl,
|
||||
)
|
||||
increment_op = RedisPipelineIncrementOperation(
|
||||
key=spend_key,
|
||||
increment_value=response_cost,
|
||||
ttl=ttl,
|
||||
)
|
||||
self.redis_increment_operation_queue.append(increment_op)
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""Original method now uses helper functions"""
|
||||
verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event")
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError("standard_logging_payload is required")
|
||||
|
||||
response_cost: float = standard_logging_payload.get("response_cost", 0)
|
||||
model_id: str = str(standard_logging_payload.get("model_id", ""))
|
||||
custom_llm_provider: str = kwargs.get("litellm_params", {}).get(
|
||||
"custom_llm_provider", None
|
||||
)
|
||||
if custom_llm_provider is None:
|
||||
raise ValueError("custom_llm_provider is required")
|
||||
|
||||
budget_config = self._get_budget_config_for_provider(custom_llm_provider)
|
||||
if budget_config:
|
||||
# increment spend for provider
|
||||
spend_key = (
|
||||
f"provider_spend:{custom_llm_provider}:{budget_config.budget_duration}"
|
||||
)
|
||||
start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=budget_config,
|
||||
spend_key=spend_key,
|
||||
start_time_key=start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
|
||||
deployment_budget_config = self._get_budget_config_for_deployment(model_id)
|
||||
if deployment_budget_config:
|
||||
# increment spend for specific deployment id
|
||||
deployment_spend_key = f"deployment_spend:{model_id}:{deployment_budget_config.budget_duration}"
|
||||
deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=deployment_budget_config,
|
||||
spend_key=deployment_spend_key,
|
||||
start_time_key=deployment_start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
|
||||
request_tags = _get_tags_from_request_kwargs(kwargs)
|
||||
if len(request_tags) > 0:
|
||||
for _tag in request_tags:
|
||||
_tag_budget_config = self._get_budget_config_for_tag(_tag)
|
||||
if _tag_budget_config:
|
||||
_tag_spend_key = (
|
||||
f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
|
||||
)
|
||||
_tag_start_time_key = f"tag_budget_start_time:{_tag}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=_tag_budget_config,
|
||||
spend_key=_tag_spend_key,
|
||||
start_time_key=_tag_start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
|
||||
async def _increment_spend_for_key(
|
||||
self,
|
||||
budget_config: GenericBudgetInfo,
|
||||
spend_key: str,
|
||||
start_time_key: str,
|
||||
response_cost: float,
|
||||
):
|
||||
if budget_config.budget_duration is None:
|
||||
return
|
||||
|
||||
current_time = datetime.now(timezone.utc).timestamp()
|
||||
ttl_seconds = duration_in_seconds(budget_config.budget_duration)
|
||||
|
||||
budget_start = await self._get_or_set_budget_start_time(
|
||||
start_time_key=start_time_key,
|
||||
current_time=current_time,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
|
||||
if budget_start is None:
|
||||
# First spend for this provider
|
||||
budget_start = await self._handle_new_budget_window(
|
||||
spend_key=spend_key,
|
||||
start_time_key=start_time_key,
|
||||
current_time=current_time,
|
||||
response_cost=response_cost,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
elif (current_time - budget_start) > ttl_seconds:
|
||||
# Budget window expired - reset everything
|
||||
verbose_router_logger.debug("Budget window expired - resetting everything")
|
||||
budget_start = await self._handle_new_budget_window(
|
||||
spend_key=spend_key,
|
||||
start_time_key=start_time_key,
|
||||
current_time=current_time,
|
||||
response_cost=response_cost,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
else:
|
||||
# Within existing window - increment spend
|
||||
remaining_time = ttl_seconds - (current_time - budget_start)
|
||||
ttl_for_increment = int(remaining_time)
|
||||
|
||||
await self._increment_spend_in_current_window(
|
||||
spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment
|
||||
)
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"Incremented spend for {spend_key} by {response_cost}"
|
||||
)
|
||||
|
||||
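The window handling above reduces to a small piece of arithmetic; a standalone sketch with made-up numbers:

```
from datetime import datetime, timezone

ttl_seconds = 86400                      # a 1d budget window
budget_start = 1_700_000_000.0           # stored at the first spend of the window
current_time = datetime.now(timezone.utc).timestamp()

if (current_time - budget_start) > ttl_seconds:
    # window expired: reset spend to this response's cost and restart the window
    pass
else:
    # within the window: increment spend, expiring with the remaining window time
    ttl_for_increment = int(ttl_seconds - (current_time - budget_start))
```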
async def periodic_sync_in_memory_spend_with_redis(self):
|
||||
"""
|
||||
Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds
|
||||
|
||||
Required for multi-instance environment usage of provider budgets
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await self._sync_in_memory_spend_with_redis()
|
||||
await asyncio.sleep(
|
||||
DEFAULT_REDIS_SYNC_INTERVAL
|
||||
) # Wait for DEFAULT_REDIS_SYNC_INTERVAL seconds before next sync
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
|
||||
await asyncio.sleep(
|
||||
DEFAULT_REDIS_SYNC_INTERVAL
|
||||
) # Still wait DEFAULT_REDIS_SYNC_INTERVAL seconds on error before retrying
|
||||
|
||||
async def _push_in_memory_increments_to_redis(self):
|
||||
"""
|
||||
How this works:
|
||||
- async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
|
||||
- This function pushes all increments to Redis in a batched pipeline to optimize performance
|
||||
|
||||
Only runs if Redis is initialized
|
||||
"""
|
||||
try:
|
||||
if not self.dual_cache.redis_cache:
|
||||
return # Redis is not initialized
|
||||
|
||||
verbose_router_logger.debug(
|
||||
"Pushing Redis Increment Pipeline for queue: %s",
|
||||
self.redis_increment_operation_queue,
|
||||
)
|
||||
if len(self.redis_increment_operation_queue) > 0:
|
||||
asyncio.create_task(
|
||||
self.dual_cache.redis_cache.async_increment_pipeline(
|
||||
increment_list=self.redis_increment_operation_queue,
|
||||
)
|
||||
)
|
||||
|
||||
self.redis_increment_operation_queue = []
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(
|
||||
f"Error syncing in-memory cache with Redis: {str(e)}"
|
||||
)
|
||||
|
||||
async def _sync_in_memory_spend_with_redis(self):
|
||||
"""
|
||||
Ensures in-memory cache is updated with latest Redis values for all provider spends.
|
||||
|
||||
Why do we need this?
- Optimization to hit sub-100ms latency. Performance was impacted when Redis was used for a read/write on every request
- To use provider budgets in a multi-instance environment, Redis is used to sync spend across all instances
|
||||
|
||||
What this does:
|
||||
1. Push all provider spend increments to Redis
|
||||
2. Fetch all current provider spend from Redis to update in-memory cache
|
||||
"""
|
||||
|
||||
try:
|
||||
# No need to sync if Redis cache is not initialized
|
||||
if self.dual_cache.redis_cache is None:
|
||||
return
|
||||
|
||||
# 1. Push all provider spend increments to Redis
|
||||
await self._push_in_memory_increments_to_redis()
|
||||
|
||||
# 2. Fetch all current provider spend from Redis to update in-memory cache
|
||||
cache_keys = []
|
||||
|
||||
if self.provider_budget_config is not None:
|
||||
for provider, config in self.provider_budget_config.items():
|
||||
if config is None:
|
||||
continue
|
||||
cache_keys.append(
|
||||
f"provider_spend:{provider}:{config.budget_duration}"
|
||||
)
|
||||
|
||||
if self.deployment_budget_config is not None:
|
||||
for model_id, config in self.deployment_budget_config.items():
|
||||
if config is None:
|
||||
continue
|
||||
cache_keys.append(
|
||||
f"deployment_spend:{model_id}:{config.budget_duration}"
|
||||
)
|
||||
|
||||
if self.tag_budget_config is not None:
|
||||
for tag, config in self.tag_budget_config.items():
|
||||
if config is None:
|
||||
continue
|
||||
cache_keys.append(f"tag_spend:{tag}:{config.budget_duration}")
|
||||
|
||||
# Batch fetch current spend values from Redis
|
||||
redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
|
||||
key_list=cache_keys
|
||||
)
|
||||
|
||||
# Update in-memory cache with Redis values
|
||||
if isinstance(redis_values, dict): # Check if redis_values is a dictionary
|
||||
for key, value in redis_values.items():
|
||||
if value is not None:
|
||||
await self.dual_cache.in_memory_cache.async_set_cache(
|
||||
key=key, value=float(value)
|
||||
)
|
||||
verbose_router_logger.debug(
|
||||
f"Updated in-memory cache for {key}: {value}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(
|
||||
f"Error syncing in-memory cache with Redis: {str(e)}"
|
||||
)
|
||||
|
||||
def _get_budget_config_for_deployment(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Optional[GenericBudgetInfo]:
|
||||
if self.deployment_budget_config is None:
|
||||
return None
|
||||
return self.deployment_budget_config.get(model_id, None)
|
||||
|
||||
def _get_budget_config_for_provider(
|
||||
self, provider: str
|
||||
) -> Optional[GenericBudgetInfo]:
|
||||
if self.provider_budget_config is None:
|
||||
return None
|
||||
return self.provider_budget_config.get(provider, None)
|
||||
|
||||
def _get_budget_config_for_tag(self, tag: str) -> Optional[GenericBudgetInfo]:
|
||||
if self.tag_budget_config is None:
|
||||
return None
|
||||
return self.tag_budget_config.get(tag, None)
|
||||
|
||||
def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
|
||||
try:
|
||||
_litellm_params: LiteLLM_Params = LiteLLM_Params(
|
||||
**deployment.get("litellm_params", {"model": ""})
|
||||
)
|
||||
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
||||
model=_litellm_params.model,
|
||||
litellm_params=_litellm_params,
|
||||
)
|
||||
except Exception:
|
||||
verbose_router_logger.error(
|
||||
f"Error getting LLM provider for deployment: {deployment}"
|
||||
)
|
||||
return None
|
||||
return custom_llm_provider
|
||||
|
||||
def _track_provider_remaining_budget_prometheus(
|
||||
self, provider: str, spend: float, budget_limit: float
|
||||
):
|
||||
"""
|
||||
Optional helper - emit provider remaining budget metric to Prometheus
|
||||
|
||||
This is helpful for debugging and monitoring provider budget limits.
|
||||
"""
|
||||
|
||||
prometheus_logger = _get_prometheus_logger_from_callbacks()
|
||||
if prometheus_logger:
|
||||
prometheus_logger.track_provider_remaining_budget(
|
||||
provider=provider,
|
||||
spend=spend,
|
||||
budget_limit=budget_limit,
|
||||
)
|
||||
|
||||
async def _get_current_provider_spend(self, provider: str) -> Optional[float]:
|
||||
"""
|
||||
GET the current spend for a provider from cache
|
||||
|
||||
used for GET /provider/budgets endpoint in spend_management_endpoints.py
|
||||
|
||||
Args:
|
||||
provider (str): The provider to get spend for (e.g., "openai", "anthropic")
|
||||
|
||||
Returns:
|
||||
Optional[float]: The current spend for the provider, or None if not found
|
||||
"""
|
||||
budget_config = self._get_budget_config_for_provider(provider)
|
||||
if budget_config is None:
|
||||
return None
|
||||
|
||||
spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
|
||||
|
||||
if self.dual_cache.redis_cache:
|
||||
# use Redis as source of truth since that has spend across all instances
|
||||
current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
|
||||
else:
|
||||
# use in-memory cache if Redis is not initialized
|
||||
current_spend = await self.dual_cache.async_get_cache(spend_key)
|
||||
return float(current_spend) if current_spend is not None else 0.0
|
||||
|
||||
async def _get_current_provider_budget_reset_at(
|
||||
self, provider: str
|
||||
) -> Optional[str]:
|
||||
budget_config = self._get_budget_config_for_provider(provider)
|
||||
if budget_config is None:
|
||||
return None
|
||||
|
||||
spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
|
||||
if self.dual_cache.redis_cache:
|
||||
ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
|
||||
else:
|
||||
ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)
|
||||
|
||||
if ttl_seconds is None:
|
||||
return None
|
||||
|
||||
return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
|
||||
|
||||
async def _init_provider_budget_in_cache(
|
||||
self, provider: str, budget_config: GenericBudgetInfo
|
||||
):
|
||||
"""
|
||||
Initialize provider budget in cache by storing the following keys if they don't exist:
|
||||
- provider_spend:{provider}:{budget_config.budget_duration} - stores the current spend
|
||||
- provider_budget_start_time:{provider} - stores the start time of the budget window
|
||||
|
||||
"""
|
||||
|
||||
spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
|
||||
start_time_key = f"provider_budget_start_time:{provider}"
|
||||
ttl_seconds: Optional[int] = None
|
||||
if budget_config.budget_duration is not None:
|
||||
ttl_seconds = duration_in_seconds(budget_config.budget_duration)
|
||||
|
||||
budget_start = await self.dual_cache.async_get_cache(start_time_key)
|
||||
if budget_start is None:
|
||||
budget_start = datetime.now(timezone.utc).timestamp()
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=budget_start, ttl=ttl_seconds
|
||||
)
|
||||
|
||||
_spend_key = await self.dual_cache.async_get_cache(spend_key)
|
||||
if _spend_key is None:
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=spend_key, value=0.0, ttl=ttl_seconds
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_init_router_budget_limiter(
|
||||
provider_budget_config: Optional[dict],
|
||||
model_list: Optional[
|
||||
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
|
||||
] = None,
|
||||
):
|
||||
"""
|
||||
Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized
|
||||
|
||||
Either:
|
||||
- provider_budget_config is set
|
||||
- budgets are set for deployments in the model_list
|
||||
- tag_budget_config is set
|
||||
"""
|
||||
if provider_budget_config is not None:
|
||||
return True
|
||||
|
||||
if litellm.tag_budget_config is not None:
|
||||
return True
|
||||
|
||||
if model_list is None:
|
||||
return False
|
||||
|
||||
for _model in model_list:
|
||||
_litellm_params = _model.get("litellm_params", {})
|
||||
if (
|
||||
_litellm_params.get("max_budget")
|
||||
or _litellm_params.get("budget_duration") is not None
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _init_provider_budgets(self):
|
||||
if self.provider_budget_config is not None:
|
||||
# cast elements of provider_budget_config to GenericBudgetInfo
|
||||
for provider, config in self.provider_budget_config.items():
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}"
|
||||
)
|
||||
|
||||
if not isinstance(config, GenericBudgetInfo):
|
||||
self.provider_budget_config[provider] = GenericBudgetInfo(
|
||||
budget_limit=config.get("budget_limit"),
|
||||
time_period=config.get("time_period"),
|
||||
)
|
||||
asyncio.create_task(
|
||||
self._init_provider_budget_in_cache(
|
||||
provider=provider,
|
||||
budget_config=self.provider_budget_config[provider],
|
||||
)
|
||||
)
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"Initalized Provider budget config: {self.provider_budget_config}"
|
||||
)
|
||||
|
||||
def _init_deployment_budgets(
|
||||
self,
|
||||
model_list: Optional[
|
||||
Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
|
||||
] = None,
|
||||
):
|
||||
if model_list is None:
|
||||
return
|
||||
for _model in model_list:
|
||||
_litellm_params = _model.get("litellm_params", {})
|
||||
_model_info: Dict = _model.get("model_info") or {}
|
||||
_model_id = _model_info.get("id")
|
||||
_max_budget = _litellm_params.get("max_budget")
|
||||
_budget_duration = _litellm_params.get("budget_duration")
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}"
|
||||
)
|
||||
if (
|
||||
_max_budget is not None
|
||||
and _budget_duration is not None
|
||||
and _model_id is not None
|
||||
):
|
||||
_budget_config = GenericBudgetInfo(
|
||||
time_period=_budget_duration,
|
||||
budget_limit=_max_budget,
|
||||
)
|
||||
if self.deployment_budget_config is None:
|
||||
self.deployment_budget_config = {}
|
||||
self.deployment_budget_config[_model_id] = _budget_config
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"Initialized Deployment Budget Config: {self.deployment_budget_config}"
|
||||
)
|
||||
|
||||
def _init_tag_budgets(self):
|
||||
if litellm.tag_budget_config is None:
|
||||
return
|
||||
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
|
||||
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Tag budgets are an Enterprise only feature, {CommonProxyErrors.not_premium_user}"
|
||||
)
|
||||
|
||||
if self.tag_budget_config is None:
|
||||
self.tag_budget_config = {}
|
||||
|
||||
for _tag, _tag_budget_config in litellm.tag_budget_config.items():
|
||||
if isinstance(_tag_budget_config, dict):
|
||||
_tag_budget_config = BudgetConfig(**_tag_budget_config)
|
||||
_generic_budget_config = GenericBudgetInfo(
|
||||
time_period=_tag_budget_config.budget_duration,
|
||||
budget_limit=_tag_budget_config.max_budget,
|
||||
)
|
||||
self.tag_budget_config[_tag] = _generic_budget_config
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"Initialized Tag Budget Config: {self.tag_budget_config}"
|
||||
)
|
||||
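The tag budgets initialized above are driven by `litellm.tag_budget_config` (an Enterprise-gated feature per `_init_tag_budgets`). A sketch of the expected shape; the `max_budget` / `budget_duration` field names are taken from how this module reads `BudgetConfig`, and passing them directly to the constructor is an assumption:

```
import litellm
from litellm.types.utils import BudgetConfig

# illustrative: spend for requests tagged "teamA" capped at $50 per 30 days
litellm.tag_budget_config = {
    "teamA": BudgetConfig(max_budget=50.0, budget_duration="30d"),
}
```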
@@ -0,0 +1,252 @@
|
||||
#### What this does ####
|
||||
# identifies least busy deployment
|
||||
# How is this achieved?
|
||||
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
|
||||
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
|
||||
# - use litellm.success + failure callbacks to log when a request completed
|
||||
# - in get_available_deployment, for a given model group name -> pick based on traffic
|
||||
|
||||
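In other words, the strategy keeps one in-flight counter dict per model group and picks the deployment id with the smallest count; a tiny illustration with made-up ids:

```
# illustrative cache entry stored under the key "gpt-4_request_count"
all_deployments = {"deployment-a": 3, "deployment-b": 0, "deployment-c": 1}

least_busy_id = min(all_deployments, key=all_deployments.get)  # -> "deployment-b"
```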
import random
|
||||
from typing import Optional
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class LeastBusyLoggingHandler(CustomLogger):
|
||||
test_flag: bool = False
|
||||
logged_success: int = 0
|
||||
logged_failure: int = 0
|
||||
|
||||
def __init__(self, router_cache: DualCache, model_list: list):
|
||||
self.router_cache = router_cache
|
||||
self.mapping_deployment_to_id: dict = {}
|
||||
self.model_list = model_list
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
"""
|
||||
Log when a model is being used.
|
||||
|
||||
Caching based on model group.
|
||||
"""
|
||||
try:
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# update cache
|
||||
request_count_dict = (
|
||||
self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
)
|
||||
request_count_dict[id] = request_count_dict.get(id, 0) + 1
|
||||
|
||||
self.router_cache.set_cache(
|
||||
key=request_count_api_key, value=request_count_dict
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = (
|
||||
self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
)
|
||||
request_count_value: Optional[int] = request_count_dict.get(id, 0)
|
||||
if request_count_value is None:
|
||||
return
|
||||
request_count_dict[id] = request_count_value - 1
|
||||
self.router_cache.set_cache(
|
||||
key=request_count_api_key, value=request_count_dict
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = (
|
||||
self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
)
|
||||
request_count_value: Optional[int] = request_count_dict.get(id, 0)
|
||||
if request_count_value is None:
|
||||
return
|
||||
request_count_dict[id] = request_count_value - 1
|
||||
self.router_cache.set_cache(
|
||||
key=request_count_api_key, value=request_count_dict
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_failure += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=request_count_api_key)
|
||||
or {}
|
||||
)
|
||||
request_count_value: Optional[int] = request_count_dict.get(id, 0)
|
||||
if request_count_value is None:
|
||||
return
|
||||
request_count_dict[id] = request_count_value - 1
|
||||
await self.router_cache.async_set_cache(
|
||||
key=request_count_api_key, value=request_count_dict
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
# decrement count in cache
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=request_count_api_key)
|
||||
or {}
|
||||
)
|
||||
request_count_value: Optional[int] = request_count_dict.get(id, 0)
|
||||
if request_count_value is None:
|
||||
return
|
||||
request_count_dict[id] = request_count_value - 1
|
||||
await self.router_cache.async_set_cache(
|
||||
key=request_count_api_key, value=request_count_dict
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_failure += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_available_deployments(
|
||||
self,
|
||||
healthy_deployments: list,
|
||||
all_deployments: dict,
|
||||
):
|
||||
"""
|
||||
Helper to get deployments using least busy strategy
|
||||
"""
|
||||
for d in healthy_deployments:
|
||||
## if healthy deployment not yet used
|
||||
if d["model_info"]["id"] not in all_deployments:
|
||||
all_deployments[d["model_info"]["id"]] = 0
|
||||
# map deployment to id
|
||||
# pick least busy deployment
|
||||
min_traffic = float("inf")
|
||||
min_deployment = None
|
||||
for k, v in all_deployments.items():
|
||||
if v < min_traffic:
|
||||
min_traffic = v
|
||||
min_deployment = k
|
||||
if min_deployment is not None:
|
||||
## return the full deployment dict for the least busy id; fall back to a random healthy deployment
|
||||
for m in healthy_deployments:
|
||||
if m["model_info"]["id"] == min_deployment:
|
||||
return m
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
else:
|
||||
min_deployment = random.choice(healthy_deployments)
|
||||
return min_deployment
|
||||
|
||||
def get_available_deployments(
|
||||
self,
|
||||
model_group: str,
|
||||
healthy_deployments: list,
|
||||
):
|
||||
"""
|
||||
Sync helper to get deployments using least busy strategy
|
||||
"""
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
|
||||
return self._get_available_deployments(
|
||||
healthy_deployments=healthy_deployments,
|
||||
all_deployments=all_deployments,
|
||||
)
|
||||
|
||||
async def async_get_available_deployments(
|
||||
self, model_group: str, healthy_deployments: list
|
||||
):
|
||||
"""
|
||||
Async helper to get deployments using least busy strategy
|
||||
"""
|
||||
request_count_api_key = f"{model_group}_request_count"
|
||||
all_deployments = (
|
||||
await self.router_cache.async_get_cache(key=request_count_api_key) or {}
|
||||
)
|
||||
return self._get_available_deployments(
|
||||
healthy_deployments=healthy_deployments,
|
||||
all_deployments=all_deployments,
|
||||
)
|
||||
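A quick sketch of exercising this handler directly. It assumes `DualCache()` works with its defaults and that the handler lives at `litellm.router_strategy.least_busy` (the diff does not show file paths); the deployment dicts are minimal stand-ins:

```
from litellm.caching.caching import DualCache
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler  # path assumed

handler = LeastBusyLoggingHandler(router_cache=DualCache(), model_list=[])
healthy = [
    {"model_info": {"id": "deployment-a"}},
    {"model_info": {"id": "deployment-b"}},
]
# with no traffic recorded yet, both deployments start at 0 in-flight requests
picked = handler.get_available_deployments(
    model_group="gpt-4", healthy_deployments=healthy
)
```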
@@ -0,0 +1,333 @@
|
||||
#### What this does ####
|
||||
# picks the deployment with the lowest cost (input + output cost per token)
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import ModelResponse, token_counter, verbose_logger
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class LowestCostLoggingHandler(CustomLogger):
|
||||
test_flag: bool = False
|
||||
logged_success: int = 0
|
||||
logged_failure: int = 0
|
||||
|
||||
def __init__(
|
||||
self, router_cache: DualCache, model_list: list, routing_args: dict = {}
|
||||
):
|
||||
self.router_cache = router_cache
|
||||
self.model_list = model_list
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
"""
|
||||
Update usage on success
|
||||
"""
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
"""
|
||||
{
|
||||
{model_group}_map: {
|
||||
id: {
|
||||
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
cost_key = f"{model_group}_map"
|
||||
|
||||
response_ms: timedelta = end_time - start_time
|
||||
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
_usage = getattr(response_obj, "usage", None)
|
||||
if _usage is not None and isinstance(_usage, litellm.Usage):
|
||||
completion_tokens = _usage.completion_tokens
|
||||
total_tokens = _usage.total_tokens
|
||||
float(response_ms.total_seconds() / completion_tokens)
|
||||
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
|
||||
request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
|
||||
|
||||
# check local result first
|
||||
|
||||
if id not in request_count_dict:
|
||||
request_count_dict[id] = {}
|
||||
|
||||
if precise_minute not in request_count_dict[id]:
|
||||
request_count_dict[id][precise_minute] = {}
|
||||
|
||||
## TPM
|
||||
request_count_dict[id][precise_minute]["tpm"] = (
|
||||
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
|
||||
)
|
||||
|
||||
## RPM
|
||||
request_count_dict[id][precise_minute]["rpm"] = (
|
||||
request_count_dict[id][precise_minute].get("rpm", 0) + 1
|
||||
)
|
||||
|
||||
self.router_cache.set_cache(key=cost_key, value=request_count_dict)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.router_strategy.lowest_cost.py::log_success_event(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
"""
|
||||
Update cost usage on success
|
||||
"""
|
||||
if kwargs["litellm_params"].get("metadata") is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"]["metadata"].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
"""
|
||||
{
|
||||
{model_group}_map: {
|
||||
id: {
|
||||
"cost": [..]
|
||||
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
cost_key = f"{model_group}_map"
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
|
||||
response_ms: timedelta = end_time - start_time
|
||||
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
_usage = getattr(response_obj, "usage", None)
|
||||
if _usage is not None and isinstance(_usage, litellm.Usage):
|
||||
completion_tokens = _usage.completion_tokens
|
||||
total_tokens = _usage.total_tokens
|
||||
|
||||
float(response_ms.total_seconds() / completion_tokens)
|
||||
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=cost_key) or {}
|
||||
)
|
||||
|
||||
if id not in request_count_dict:
|
||||
request_count_dict[id] = {}
|
||||
if precise_minute not in request_count_dict[id]:
|
||||
request_count_dict[id][precise_minute] = {}
|
||||
|
||||
## TPM
|
||||
request_count_dict[id][precise_minute]["tpm"] = (
|
||||
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
|
||||
)
|
||||
|
||||
## RPM
|
||||
request_count_dict[id][precise_minute]["rpm"] = (
|
||||
request_count_dict[id][precise_minute].get("rpm", 0) + 1
|
||||
)
|
||||
|
||||
await self.router_cache.async_set_cache(
|
||||
key=cost_key, value=request_count_dict
|
||||
) # reset map within window
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_get_available_deployments( # noqa: PLR0915
|
||||
self,
|
||||
model_group: str,
|
||||
healthy_deployments: list,
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
request_kwargs: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Returns a deployment with the lowest cost
|
||||
"""
|
||||
cost_key = f"{model_group}_map"
|
||||
|
||||
request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}
|
||||
|
||||
# -----------------------
|
||||
# Find lowest used model
|
||||
# ----------------------
|
||||
float("inf")
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
|
||||
if request_count_dict is None: # base case
|
||||
return
|
||||
|
||||
all_deployments = request_count_dict
|
||||
for d in healthy_deployments:
|
||||
## if healthy deployment not yet used
|
||||
if d["model_info"]["id"] not in all_deployments:
|
||||
all_deployments[d["model_info"]["id"]] = {
|
||||
precise_minute: {"tpm": 0, "rpm": 0},
|
||||
}
|
||||
|
||||
try:
|
||||
input_tokens = token_counter(messages=messages, text=input)
|
||||
except Exception:
|
||||
input_tokens = 0
|
||||
|
||||
# randomly sample from all_deployments, in case all deployments have cost=0.0
|
||||
_items = all_deployments.items()
|
||||
|
||||
### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
|
||||
potential_deployments = []
|
||||
_cost_per_deployment = {}
|
||||
for item, item_map in all_deployments.items():
|
||||
## get the item from model list
|
||||
_deployment = None
|
||||
for m in healthy_deployments:
|
||||
if item == m["model_info"]["id"]:
|
||||
_deployment = m
|
||||
|
||||
if _deployment is None:
|
||||
continue # skip to next one
|
||||
|
||||
_deployment_tpm = (
|
||||
_deployment.get("tpm", None)
|
||||
or _deployment.get("litellm_params", {}).get("tpm", None)
|
||||
or _deployment.get("model_info", {}).get("tpm", None)
|
||||
or float("inf")
|
||||
)
|
||||
|
||||
_deployment_rpm = (
|
||||
_deployment.get("rpm", None)
|
||||
or _deployment.get("litellm_params", {}).get("rpm", None)
|
||||
or _deployment.get("model_info", {}).get("rpm", None)
|
||||
or float("inf")
|
||||
)
|
||||
item_litellm_model_name = _deployment.get("litellm_params", {}).get("model")
|
||||
item_litellm_model_cost_map = litellm.model_cost.get(
|
||||
item_litellm_model_name, {}
|
||||
)
|
||||
|
||||
# check if user provided input_cost_per_token and output_cost_per_token in litellm_params
|
||||
item_input_cost = None
|
||||
item_output_cost = None
|
||||
if _deployment.get("litellm_params", {}).get("input_cost_per_token", None):
|
||||
item_input_cost = _deployment.get("litellm_params", {}).get(
|
||||
"input_cost_per_token"
|
||||
)
|
||||
|
||||
if _deployment.get("litellm_params", {}).get("output_cost_per_token", None):
|
||||
item_output_cost = _deployment.get("litellm_params", {}).get(
|
||||
"output_cost_per_token"
|
||||
)
|
||||
|
||||
if item_input_cost is None:
|
||||
item_input_cost = item_litellm_model_cost_map.get(
|
||||
"input_cost_per_token", 5.0
|
||||
)
|
||||
|
||||
if item_output_cost is None:
|
||||
item_output_cost = item_litellm_model_cost_map.get(
|
||||
"output_cost_per_token", 5.0
|
||||
)
|
||||
|
||||
# if litellm["model"] is not in model_cost map -> use item_cost = $10
|
||||
|
||||
item_cost = item_input_cost + item_output_cost
|
||||
|
||||
item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
|
||||
item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"item_cost: {item_cost}, item_tpm: {item_tpm}, item_rpm: {item_rpm}, model_id: {_deployment.get('model_info', {}).get('id')}"
|
||||
)
|
||||
|
||||
# -------------- #
|
||||
# Debugging Logic
|
||||
# -------------- #
|
||||
# We use _cost_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
|
||||
# this helps a user to debug why the router picked a specfic deployment #
|
||||
_deployment_api_base = _deployment.get("litellm_params", {}).get(
|
||||
"api_base", ""
|
||||
)
|
||||
if _deployment_api_base is not None:
|
||||
_cost_per_deployment[_deployment_api_base] = item_cost
|
||||
# -------------- #
|
||||
# End of Debugging Logic
|
||||
# -------------- #
|
||||
|
||||
if (
|
||||
item_tpm + input_tokens > _deployment_tpm
|
||||
or item_rpm + 1 > _deployment_rpm
|
||||
): # if user passed in tpm / rpm in the model_list
|
||||
continue
|
||||
else:
|
||||
potential_deployments.append((_deployment, item_cost))
|
||||
|
||||
if len(potential_deployments) == 0:
|
||||
return None
|
||||
|
||||
potential_deployments = sorted(potential_deployments, key=lambda x: x[1])
|
||||
|
||||
selected_deployment = potential_deployments[0][0]
|
||||
return selected_deployment
|
||||
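A minimal usage sketch for the cost-based strategy above, assuming the standard litellm Router constructor and that this handler is registered under the "cost-based-routing" strategy name; the deployment names and cost values below are illustrative only, not part of this commit.

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",  # model group
            "litellm_params": {
                "model": "azure/gpt-4o",
                # optional per-deployment overrides; otherwise costs come from litellm.model_cost
                "input_cost_per_token": 0.0000025,
                "output_cost_per_token": 0.00001,
                "tpm": 100_000,
                "rpm": 100,
            },
        },
    ],
    routing_strategy="cost-based-routing",  # picks the cheapest deployment still under its tpm/rpm limits
)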
@@ -0,0 +1,590 @@
#### What this does ####
# picks based on response time (for streaming, this is time to first token)
import random
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import litellm
from litellm import ModelResponse, token_counter, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.utils import LiteLLMPydanticObjectBase

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


class RoutingArgs(LiteLLMPydanticObjectBase):
    ttl: float = 1 * 60 * 60  # 1 hour
    lowest_latency_buffer: float = 0
    max_latency_list_size: int = 10


class LowestLatencyLoggingHandler(CustomLogger):
    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0

    def __init__(
        self, router_cache: DualCache, model_list: list, routing_args: dict = {}
    ):
        self.router_cache = router_cache
        self.model_list = model_list
        self.routing_args = RoutingArgs(**routing_args)

    def log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        try:
            """
            Update latency usage on success
            """
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                # ------------
                # Setup values
                # ------------
                """
                {
                    {model_group}_map: {
                        id: {
                            "latency": [..]
                            f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
                        }
                    }
                }
                """
                latency_key = f"{model_group}_map"

                current_date = datetime.now().strftime("%Y-%m-%d")
                current_hour = datetime.now().strftime("%H")
                current_minute = datetime.now().strftime("%M")
                precise_minute = f"{current_date}-{current_hour}-{current_minute}"

                response_ms: timedelta = end_time - start_time
                time_to_first_token_response_time: Optional[timedelta] = None

                if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
                    # only log ttft for streaming request
                    time_to_first_token_response_time = (
                        kwargs.get("completion_start_time", end_time) - start_time
                    )

                final_value: Union[float, timedelta] = response_ms
                time_to_first_token: Optional[float] = None
                total_tokens = 0

                if isinstance(response_obj, ModelResponse):
                    _usage = getattr(response_obj, "usage", None)
                    if _usage is not None:
                        completion_tokens = _usage.completion_tokens
                        total_tokens = _usage.total_tokens
                        final_value = float(
                            response_ms.total_seconds() / completion_tokens
                        )

                        if time_to_first_token_response_time is not None:
                            time_to_first_token = float(
                                time_to_first_token_response_time.total_seconds()
                                / completion_tokens
                            )

                # ------------
                # Update usage
                # ------------
                parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
                request_count_dict = (
                    self.router_cache.get_cache(
                        key=latency_key, parent_otel_span=parent_otel_span
                    )
                    or {}
                )

                if id not in request_count_dict:
                    request_count_dict[id] = {}

                ## Latency
                if (
                    len(request_count_dict[id].get("latency", []))
                    < self.routing_args.max_latency_list_size
                ):
                    request_count_dict[id].setdefault("latency", []).append(final_value)
                else:
                    request_count_dict[id]["latency"] = request_count_dict[id][
                        "latency"
                    ][: self.routing_args.max_latency_list_size - 1] + [final_value]

                ## Time to first token
                if time_to_first_token is not None:
                    if (
                        len(request_count_dict[id].get("time_to_first_token", []))
                        < self.routing_args.max_latency_list_size
                    ):
                        request_count_dict[id].setdefault(
                            "time_to_first_token", []
                        ).append(time_to_first_token)
                    else:
                        request_count_dict[id][
                            "time_to_first_token"
                        ] = request_count_dict[id]["time_to_first_token"][
                            : self.routing_args.max_latency_list_size - 1
                        ] + [
                            time_to_first_token
                        ]

                if precise_minute not in request_count_dict[id]:
                    request_count_dict[id][precise_minute] = {}

                ## TPM
                request_count_dict[id][precise_minute]["tpm"] = (
                    request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
                )

                ## RPM
                request_count_dict[id][precise_minute]["rpm"] = (
                    request_count_dict[id][precise_minute].get("rpm", 0) + 1
                )

                self.router_cache.set_cache(
                    key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
                )  # reset map within window

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """
        Check if Timeout Error, if timeout set deployment latency -> 100
        """
        try:
            _exception = kwargs.get("exception", None)
            if isinstance(_exception, litellm.Timeout):
                if kwargs["litellm_params"].get("metadata") is None:
                    pass
                else:
                    model_group = kwargs["litellm_params"]["metadata"].get(
                        "model_group", None
                    )

                    id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                    if model_group is None or id is None:
                        return
                    elif isinstance(id, int):
                        id = str(id)

                    # ------------
                    # Setup values
                    # ------------
                    """
                    {
                        {model_group}_map: {
                            id: {
                                "latency": [..]
                                f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
                            }
                        }
                    }
                    """
                    latency_key = f"{model_group}_map"
                    request_count_dict = (
                        await self.router_cache.async_get_cache(key=latency_key) or {}
                    )

                    if id not in request_count_dict:
                        request_count_dict[id] = {}

                    ## Latency - give 1000s penalty for failing
                    if (
                        len(request_count_dict[id].get("latency", []))
                        < self.routing_args.max_latency_list_size
                    ):
                        request_count_dict[id].setdefault("latency", []).append(1000.0)
                    else:
                        request_count_dict[id]["latency"] = request_count_dict[id][
                            "latency"
                        ][: self.routing_args.max_latency_list_size - 1] + [1000.0]

                    await self.router_cache.async_set_cache(
                        key=latency_key,
                        value=request_count_dict,
                        ttl=self.routing_args.ttl,
                    )  # reset map within window
            else:
                # do nothing if it's not a timeout error
                return
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            pass

    async def async_log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        try:
            """
            Update latency usage on success
            """
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                # ------------
                # Setup values
                # ------------
                """
                {
                    {model_group}_map: {
                        id: {
                            "latency": [..]
                            "time_to_first_token": [..]
                            f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
                        }
                    }
                }
                """
                latency_key = f"{model_group}_map"

                current_date = datetime.now().strftime("%Y-%m-%d")
                current_hour = datetime.now().strftime("%H")
                current_minute = datetime.now().strftime("%M")
                precise_minute = f"{current_date}-{current_hour}-{current_minute}"

                response_ms: timedelta = end_time - start_time
                time_to_first_token_response_time: Optional[timedelta] = None
                if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
                    # only log ttft for streaming request
                    time_to_first_token_response_time = (
                        kwargs.get("completion_start_time", end_time) - start_time
                    )

                final_value: Union[float, timedelta] = response_ms
                total_tokens = 0
                time_to_first_token: Optional[float] = None

                if isinstance(response_obj, ModelResponse):
                    _usage = getattr(response_obj, "usage", None)
                    if _usage is not None:
                        completion_tokens = _usage.completion_tokens
                        total_tokens = _usage.total_tokens
                        final_value = float(
                            response_ms.total_seconds() / completion_tokens
                        )

                        if time_to_first_token_response_time is not None:
                            time_to_first_token = float(
                                time_to_first_token_response_time.total_seconds()
                                / completion_tokens
                            )
                # ------------
                # Update usage
                # ------------
                parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
                request_count_dict = (
                    await self.router_cache.async_get_cache(
                        key=latency_key,
                        parent_otel_span=parent_otel_span,
                        local_only=True,
                    )
                    or {}
                )

                if id not in request_count_dict:
                    request_count_dict[id] = {}

                ## Latency
                if (
                    len(request_count_dict[id].get("latency", []))
                    < self.routing_args.max_latency_list_size
                ):
                    request_count_dict[id].setdefault("latency", []).append(final_value)
                else:
                    request_count_dict[id]["latency"] = request_count_dict[id][
                        "latency"
                    ][: self.routing_args.max_latency_list_size - 1] + [final_value]

                ## Time to first token
                if time_to_first_token is not None:
                    if (
                        len(request_count_dict[id].get("time_to_first_token", []))
                        < self.routing_args.max_latency_list_size
                    ):
                        request_count_dict[id].setdefault(
                            "time_to_first_token", []
                        ).append(time_to_first_token)
                    else:
                        request_count_dict[id][
                            "time_to_first_token"
                        ] = request_count_dict[id]["time_to_first_token"][
                            : self.routing_args.max_latency_list_size - 1
                        ] + [
                            time_to_first_token
                        ]

                if precise_minute not in request_count_dict[id]:
                    request_count_dict[id][precise_minute] = {}

                ## TPM
                request_count_dict[id][precise_minute]["tpm"] = (
                    request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
                )

                ## RPM
                request_count_dict[id][precise_minute]["rpm"] = (
                    request_count_dict[id][precise_minute].get("rpm", 0) + 1
                )

                await self.router_cache.async_set_cache(
                    key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
                )  # reset map within window

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            pass

    def _get_available_deployments(  # noqa: PLR0915
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        request_kwargs: Optional[Dict] = None,
        request_count_dict: Optional[Dict] = None,
    ):
        """Common logic for both sync and async get_available_deployments"""

        # -----------------------
        # Find lowest used model
        # ----------------------
        _latency_per_deployment = {}
        lowest_latency = float("inf")

        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"

        deployment = None

        if request_count_dict is None:  # base case
            return

        all_deployments = request_count_dict
        for d in healthy_deployments:
            ## if healthy deployment not yet used
            if d["model_info"]["id"] not in all_deployments:
                all_deployments[d["model_info"]["id"]] = {
                    "latency": [0],
                    precise_minute: {"tpm": 0, "rpm": 0},
                }

        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0

        # randomly sample from all_deployments, incase all deployments have latency=0.0
        _items = all_deployments.items()

        _all_deployments = random.sample(list(_items), len(_items))
        all_deployments = dict(_all_deployments)
        ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits

        potential_deployments = []
        for item, item_map in all_deployments.items():
            ## get the item from model list
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m

            if _deployment is None:
                continue  # skip to next one

            _deployment_tpm = (
                _deployment.get("tpm", None)
                or _deployment.get("litellm_params", {}).get("tpm", None)
                or _deployment.get("model_info", {}).get("tpm", None)
                or float("inf")
            )

            _deployment_rpm = (
                _deployment.get("rpm", None)
                or _deployment.get("litellm_params", {}).get("rpm", None)
                or _deployment.get("model_info", {}).get("rpm", None)
                or float("inf")
            )
            item_latency = item_map.get("latency", [])
            item_ttft_latency = item_map.get("time_to_first_token", [])
            item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
            item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)

            # get average latency or average ttft (depending on streaming/non-streaming)
            total: float = 0.0
            if (
                request_kwargs is not None
                and request_kwargs.get("stream", None) is not None
                and request_kwargs["stream"] is True
                and len(item_ttft_latency) > 0
            ):
                for _call_latency in item_ttft_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
            else:
                for _call_latency in item_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
            item_latency = total / len(item_latency)

            # -------------- #
            # Debugging Logic
            # -------------- #
            # We use _latency_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
            # this helps a user to debug why the router picked a specfic deployment #
            _deployment_api_base = _deployment.get("litellm_params", {}).get(
                "api_base", ""
            )
            if _deployment_api_base is not None:
                _latency_per_deployment[_deployment_api_base] = item_latency
            # -------------- #
            # End of Debugging Logic
            # -------------- #

            if (
                item_tpm + input_tokens > _deployment_tpm
                or item_rpm + 1 > _deployment_rpm
            ):  # if user passed in tpm / rpm in the model_list
                continue
            else:
                potential_deployments.append((_deployment, item_latency))

        if len(potential_deployments) == 0:
            return None

        # Sort potential deployments by latency
        sorted_deployments = sorted(potential_deployments, key=lambda x: x[1])

        # Find lowest latency deployment
        lowest_latency = sorted_deployments[0][1]

        # Find deployments within buffer of lowest latency
        buffer = self.routing_args.lowest_latency_buffer * lowest_latency

        valid_deployments = [
            x for x in sorted_deployments if x[1] <= lowest_latency + buffer
        ]

        # Pick a random deployment from valid deployments
        random_valid_deployment = random.choice(valid_deployments)
        deployment = random_valid_deployment[0]

        if request_kwargs is not None and "metadata" in request_kwargs:
            request_kwargs["metadata"][
                "_latency_per_deployment"
            ] = _latency_per_deployment
        return deployment

    async def async_get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        request_kwargs: Optional[Dict] = None,
    ):
        # get list of potential deployments
        latency_key = f"{model_group}_map"

        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
            request_kwargs
        )
        request_count_dict = (
            await self.router_cache.async_get_cache(
                key=latency_key, parent_otel_span=parent_otel_span
            )
            or {}
        )

        return self._get_available_deployments(
            model_group,
            healthy_deployments,
            messages,
            input,
            request_kwargs,
            request_count_dict,
        )

    def get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Returns a deployment with the lowest latency
        """
        # get list of potential deployments
        latency_key = f"{model_group}_map"

        parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
            request_kwargs
        )
        request_count_dict = (
            self.router_cache.get_cache(
                key=latency_key, parent_otel_span=parent_otel_span
            )
            or {}
        )

        return self._get_available_deployments(
            model_group,
            healthy_deployments,
            messages,
            input,
            request_kwargs,
            request_count_dict,
        )
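A usage sketch for the latency strategy, assuming the standard Router constructor and that the routing_strategy_args keys mirror the RoutingArgs fields above; the concrete values are illustrative only.

from litellm import Router

router = Router(
    model_list=[
        {"model_name": "gpt-4o", "litellm_params": {"model": "openai/gpt-4o", "rpm": 1000}},
        {"model_name": "gpt-4o", "litellm_params": {"model": "azure/gpt-4o", "rpm": 1000}},
    ],
    routing_strategy="latency-based-routing",
    routing_strategy_args={
        "ttl": 3600,                   # keep per-deployment latency stats for 1 hour
        "lowest_latency_buffer": 0.5,  # pick randomly among deployments within 50% of the fastest
        "max_latency_list_size": 10,   # rolling window of per-request latencies
    },
)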
@@ -0,0 +1,243 @@
#### What this does ####
# identifies lowest tpm deployment
import traceback
from datetime import datetime
from typing import Dict, List, Optional, Union

from litellm import token_counter
from litellm._logging import verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMPydanticObjectBase
from litellm.utils import print_verbose


class RoutingArgs(LiteLLMPydanticObjectBase):
    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)


class LowestTPMLoggingHandler(CustomLogger):
    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour

    def __init__(
        self, router_cache: DualCache, model_list: list, routing_args: dict = {}
    ):
        self.router_cache = router_cache
        self.model_list = model_list
        self.routing_args = RoutingArgs(**routing_args)

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            """
            Update TPM/RPM usage on success
            """
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                total_tokens = response_obj["usage"]["total_tokens"]

                # ------------
                # Setup values
                # ------------
                current_minute = datetime.now().strftime("%H-%M")
                tpm_key = f"{model_group}:tpm:{current_minute}"
                rpm_key = f"{model_group}:rpm:{current_minute}"

                # ------------
                # Update usage
                # ------------

                ## TPM
                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

                self.router_cache.set_cache(
                    key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ## RPM
                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + 1

                self.router_cache.set_cache(
                    key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            verbose_router_logger.error(
                "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_router_logger.debug(traceback.format_exc())
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            """
            Update TPM/RPM usage on success
            """
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                total_tokens = response_obj["usage"]["total_tokens"]

                # ------------
                # Setup values
                # ------------
                current_minute = datetime.now().strftime("%H-%M")
                tpm_key = f"{model_group}:tpm:{current_minute}"
                rpm_key = f"{model_group}:rpm:{current_minute}"

                # ------------
                # Update usage
                # ------------
                # update cache

                ## TPM
                request_count_dict = (
                    await self.router_cache.async_get_cache(key=tpm_key) or {}
                )
                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

                await self.router_cache.async_set_cache(
                    key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ## RPM
                request_count_dict = (
                    await self.router_cache.async_get_cache(key=rpm_key) or {}
                )
                request_count_dict[id] = request_count_dict.get(id, 0) + 1

                await self.router_cache.async_set_cache(
                    key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            verbose_router_logger.error(
                "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_router_logger.debug(traceback.format_exc())
            pass

    def get_available_deployments(  # noqa: PLR0915
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ):
        """
        Returns a deployment with the lowest TPM/RPM usage.
        """
        # get list of potential deployments
        verbose_router_logger.debug(
            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
        )
        current_minute = datetime.now().strftime("%H-%M")
        tpm_key = f"{model_group}:tpm:{current_minute}"
        rpm_key = f"{model_group}:rpm:{current_minute}"

        tpm_dict = self.router_cache.get_cache(key=tpm_key)
        rpm_dict = self.router_cache.get_cache(key=rpm_key)

        verbose_router_logger.debug(
            f"tpm_key={tpm_key}, tpm_dict: {tpm_dict}, rpm_dict: {rpm_dict}"
        )
        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0
        verbose_router_logger.debug(f"input_tokens={input_tokens}")
        # -----------------------
        # Find lowest used model
        # ----------------------
        lowest_tpm = float("inf")

        if tpm_dict is None:  # base case - none of the deployments have been used
            # initialize a tpm dict with {model_id: 0}
            tpm_dict = {}
            for deployment in healthy_deployments:
                tpm_dict[deployment["model_info"]["id"]] = 0
        else:
            for d in healthy_deployments:
                ## if healthy deployment not yet used
                if d["model_info"]["id"] not in tpm_dict:
                    tpm_dict[d["model_info"]["id"]] = 0

        all_deployments = tpm_dict

        deployment = None
        for item, item_tpm in all_deployments.items():
            ## get the item from model list
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m

            if _deployment is None:
                continue  # skip to next one

            _deployment_tpm = None
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("model_info", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = float("inf")

            _deployment_rpm = None
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = float("inf")

            if item_tpm + input_tokens > _deployment_tpm:
                continue
            elif (rpm_dict is not None and item in rpm_dict) and (
                rpm_dict[item] + 1 >= _deployment_rpm
            ):
                continue
            elif item_tpm < lowest_tpm:
                lowest_tpm = item_tpm
                deployment = _deployment
        print_verbose("returning picked lowest tpm/rpm deployment.")
        return deployment
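The strategy above keys its counters by model group and wall-clock minute, so usage naturally rolls over when the minute changes and the 60s ttl expires the old bucket. A small standalone sketch of the key scheme (values illustrative):

from datetime import datetime

model_group = "gpt-4o"
current_minute = datetime.now().strftime("%H-%M")   # e.g. "14-07"
tpm_key = f"{model_group}:tpm:{current_minute}"      # "gpt-4o:tpm:14-07" -> {model_id: tokens_used}
rpm_key = f"{model_group}:rpm:{current_minute}"      # "gpt-4o:rpm:14-07" -> {model_id: requests_made}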
@@ -0,0 +1,670 @@
#### What this does ####
# identifies lowest tpm deployment
import random
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import httpx

import litellm
from litellm import token_counter
from litellm._logging import verbose_logger, verbose_router_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.types.router import RouterErrors
from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
from litellm.utils import get_utc_datetime, print_verbose

from .base_routing_strategy import BaseRoutingStrategy

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


class RoutingArgs(LiteLLMPydanticObjectBase):
    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)


class LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger):
    """
    Updated version of TPM/RPM Logging.

    Meant to work across instances.

    Caches individual models, not model_groups

    Uses batch get (redis.mget)

    Increments tpm/rpm limit using redis.incr
    """

    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour

    def __init__(
        self, router_cache: DualCache, model_list: list, routing_args: dict = {}
    ):
        self.router_cache = router_cache
        self.model_list = model_list
        self.routing_args = RoutingArgs(**routing_args)
        BaseRoutingStrategy.__init__(
            self,
            dual_cache=router_cache,
            should_batch_redis_writes=True,
            default_sync_interval=0.1,
        )

    def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
        """
        Pre-call check + update model rpm

        Returns - deployment

        Raises - RateLimitError if deployment over defined RPM limit
        """
        try:
            # ------------
            # Setup values
            # ------------

            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            model_id = deployment.get("model_info", {}).get("id")
            deployment_name = deployment.get("litellm_params", {}).get("model")
            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"

            local_result = self.router_cache.get_cache(
                key=rpm_key, local_only=True
            )  # check local result first

            deployment_rpm = None
            if deployment_rpm is None:
                deployment_rpm = deployment.get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")

            if local_result is not None and local_result >= deployment_rpm:
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
                        deployment_rpm, local_result
                    ),
                    llm_provider="",
                    model=deployment.get("litellm_params", {}).get("model"),
                    response=httpx.Response(
                        status_code=429,
                        content="{} rpm limit={}. current usage={}. id={}, model_group={}. Get the model info by calling 'router.get_model_info(id)".format(
                            RouterErrors.user_defined_ratelimit_error.value,
                            deployment_rpm,
                            local_result,
                            model_id,
                            deployment.get("model_name", ""),
                        ),
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
            else:
                # if local result below limit, check redis ## prevent unnecessary redis checks

                result = self.router_cache.increment_cache(
                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                )
                if result is not None and result > deployment_rpm:
                    raise litellm.RateLimitError(
                        message="Deployment over defined rpm limit={}. current usage={}".format(
                            deployment_rpm, result
                        ),
                        llm_provider="",
                        model=deployment.get("litellm_params", {}).get("model"),
                        response=httpx.Response(
                            status_code=429,
                            content="{} rpm limit={}. current usage={}".format(
                                RouterErrors.user_defined_ratelimit_error.value,
                                deployment_rpm,
                                result,
                            ),
                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                        ),
                    )
            return deployment
        except Exception as e:
            if isinstance(e, litellm.RateLimitError):
                raise e
            return deployment  # don't fail calls if eg. redis fails to connect

    async def async_pre_call_check(
        self, deployment: Dict, parent_otel_span: Optional[Span]
    ) -> Optional[Dict]:
        """
        Pre-call check + update model rpm
        - Used inside semaphore
        - raise rate limit error if deployment over limit

        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994

        Returns - deployment

        Raises - RateLimitError if deployment over defined RPM limit
        """
        try:
            # ------------
            # Setup values
            # ------------
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            model_id = deployment.get("model_info", {}).get("id")
            deployment_name = deployment.get("litellm_params", {}).get("model")

            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
            local_result = await self.router_cache.async_get_cache(
                key=rpm_key, local_only=True
            )  # check local result first

            deployment_rpm = None
            if deployment_rpm is None:
                deployment_rpm = deployment.get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")
            if local_result is not None and local_result >= deployment_rpm:
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
                        deployment_rpm, local_result
                    ),
                    llm_provider="",
                    model=deployment.get("litellm_params", {}).get("model"),
                    response=httpx.Response(
                        status_code=429,
                        content="{} rpm limit={}. current usage={}".format(
                            RouterErrors.user_defined_ratelimit_error.value,
                            deployment_rpm,
                            local_result,
                        ),
                        headers={"retry-after": str(60)},  # type: ignore
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                    num_retries=deployment.get("num_retries"),
                )
            else:
                # if local result below limit, check redis ## prevent unnecessary redis checks
                result = await self._increment_value_in_current_window(
                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                )
                if result is not None and result > deployment_rpm:
                    raise litellm.RateLimitError(
                        message="Deployment over defined rpm limit={}. current usage={}".format(
                            deployment_rpm, result
                        ),
                        llm_provider="",
                        model=deployment.get("litellm_params", {}).get("model"),
                        response=httpx.Response(
                            status_code=429,
                            content="{} rpm limit={}. current usage={}".format(
                                RouterErrors.user_defined_ratelimit_error.value,
                                deployment_rpm,
                                result,
                            ),
                            headers={"retry-after": str(60)},  # type: ignore
                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                        ),
                        num_retries=deployment.get("num_retries"),
                    )
            return deployment
        except Exception as e:
            if isinstance(e, litellm.RateLimitError):
                raise e
            return deployment  # don't fail calls if eg. redis fails to connect

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            """
            Update TPM/RPM usage on success
            """
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object"
            )
            if standard_logging_object is None:
                raise ValueError("standard_logging_object not passed in.")
            model_group = standard_logging_object.get("model_group")
            model = standard_logging_object["hidden_params"].get("litellm_model_name")
            id = standard_logging_object.get("model_id")
            if model_group is None or id is None or model is None:
                return
            elif isinstance(id, int):
                id = str(id)

            total_tokens = standard_logging_object.get("total_tokens")

            # ------------
            # Setup values
            # ------------
            dt = get_utc_datetime()
            current_minute = dt.strftime(
                "%H-%M"
            )  # use the same timezone regardless of system clock

            tpm_key = f"{id}:{model}:tpm:{current_minute}"
            # ------------
            # Update usage
            # ------------
            # update cache

            ## TPM
            self.router_cache.increment_cache(
                key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
            )
            ### TESTING ###
            if self.test_flag:
                self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            """
            Update TPM usage on success
            """
            standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object"
            )
            if standard_logging_object is None:
                raise ValueError("standard_logging_object not passed in.")
            model_group = standard_logging_object.get("model_group")
            model = standard_logging_object["hidden_params"]["litellm_model_name"]
            id = standard_logging_object.get("model_id")
            if model_group is None or id is None:
                return
            elif isinstance(id, int):
                id = str(id)
            total_tokens = standard_logging_object.get("total_tokens")
            # ------------
            # Setup values
            # ------------
            dt = get_utc_datetime()
            current_minute = dt.strftime(
                "%H-%M"
            )  # use the same timezone regardless of system clock

            tpm_key = f"{id}:{model}:tpm:{current_minute}"
            # ------------
            # Update usage
            # ------------
            # update cache
            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
            ## TPM
            await self.router_cache.async_increment_cache(
                key=tpm_key,
                value=total_tokens,
                ttl=self.routing_args.ttl,
                parent_otel_span=parent_otel_span,
            )

            ### TESTING ###
            if self.test_flag:
                self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            pass

    def _return_potential_deployments(
        self,
        healthy_deployments: List[Dict],
        all_deployments: Dict,
        input_tokens: int,
        rpm_dict: Dict,
    ):
        lowest_tpm = float("inf")
        potential_deployments = []  # if multiple deployments have the same low value
        for item, item_tpm in all_deployments.items():
            ## get the item from model list
            _deployment = None
            item = item.split(":")[0]
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m
            if _deployment is None:
                continue  # skip to next one
            elif item_tpm is None:
                continue  # skip if unhealthy deployment

            _deployment_tpm = None
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("model_info", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = float("inf")

            _deployment_rpm = None
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = float("inf")
            if item_tpm + input_tokens > _deployment_tpm:
                continue
            elif (
                (rpm_dict is not None and item in rpm_dict)
                and rpm_dict[item] is not None
                and (rpm_dict[item] + 1 >= _deployment_rpm)
            ):
                continue
            elif item_tpm == lowest_tpm:
                potential_deployments.append(_deployment)
            elif item_tpm < lowest_tpm:
                lowest_tpm = item_tpm
                potential_deployments = [_deployment]
        return potential_deployments

    def _common_checks_available_deployment(  # noqa: PLR0915
        self,
        model_group: str,
        healthy_deployments: list,
        tpm_keys: list,
        tpm_values: Optional[list],
        rpm_keys: list,
        rpm_values: Optional[list],
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ) -> Optional[dict]:
        """
        Common checks for get available deployment, across sync + async implementations
        """

        if tpm_values is None or rpm_values is None:
            return None

        tpm_dict = {}  # {model_id: 1, ..}
        for idx, key in enumerate(tpm_keys):
            tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx]

        rpm_dict = {}  # {model_id: 1, ..}
        for idx, key in enumerate(rpm_keys):
            rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx]

        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0
        verbose_router_logger.debug(f"input_tokens={input_tokens}")
        # -----------------------
        # Find lowest used model
        # ----------------------

        if tpm_dict is None:  # base case - none of the deployments have been used
            # initialize a tpm dict with {model_id: 0}
            tpm_dict = {}
            for deployment in healthy_deployments:
                tpm_dict[deployment["model_info"]["id"]] = 0
        else:
            for d in healthy_deployments:
                ## if healthy deployment not yet used
                tpm_key = d["model_info"]["id"]
                if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
                    tpm_dict[tpm_key] = 0

        all_deployments = tpm_dict
        potential_deployments = self._return_potential_deployments(
            healthy_deployments=healthy_deployments,
            all_deployments=all_deployments,
            input_tokens=input_tokens,
            rpm_dict=rpm_dict,
        )
        print_verbose("returning picked lowest tpm/rpm deployment.")

        if len(potential_deployments) > 0:
            return random.choice(potential_deployments)
        else:
            return None

    async def async_get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ):
        """
        Async implementation of get deployments.

        Reduces time to retrieve the tpm/rpm values from cache
        """
        # get list of potential deployments
        verbose_router_logger.debug(
            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
        )

        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")

        tpm_keys = []
        rpm_keys = []
        for m in healthy_deployments:
            if isinstance(m, dict):
                id = m.get("model_info", {}).get(
                    "id"
                )  # a deployment should always have an 'id'. this is set in router.py
                deployment_name = m.get("litellm_params", {}).get("model")
                tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
                rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)

                tpm_keys.append(tpm_key)
                rpm_keys.append(rpm_key)

        combined_tpm_rpm_keys = tpm_keys + rpm_keys

        combined_tpm_rpm_values = await self.router_cache.async_batch_get_cache(
            keys=combined_tpm_rpm_keys
        )  # [1, 2, None, ..]

        if combined_tpm_rpm_values is not None:
            tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
            rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
        else:
            tpm_values = None
            rpm_values = None

        deployment = self._common_checks_available_deployment(
            model_group=model_group,
            healthy_deployments=healthy_deployments,
            tpm_keys=tpm_keys,
            tpm_values=tpm_values,
            rpm_keys=rpm_keys,
            rpm_values=rpm_values,
            messages=messages,
            input=input,
        )

        try:
            assert deployment is not None
            return deployment
        except Exception:
            ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
            deployment_dict = {}
            for index, _deployment in enumerate(healthy_deployments):
                if isinstance(_deployment, dict):
                    id = _deployment.get("model_info", {}).get("id")
                    ### GET DEPLOYMENT TPM LIMIT ###
                    _deployment_tpm = None
                    if _deployment_tpm is None:
                        _deployment_tpm = _deployment.get("tpm", None)
                    if _deployment_tpm is None:
                        _deployment_tpm = _deployment.get("litellm_params", {}).get(
                            "tpm", None
                        )
                    if _deployment_tpm is None:
                        _deployment_tpm = _deployment.get("model_info", {}).get(
                            "tpm", None
                        )
                    if _deployment_tpm is None:
                        _deployment_tpm = float("inf")

                    ### GET CURRENT TPM ###
                    current_tpm = tpm_values[index] if tpm_values else 0

                    ### GET DEPLOYMENT TPM LIMIT ###
                    _deployment_rpm = None
                    if _deployment_rpm is None:
                        _deployment_rpm = _deployment.get("rpm", None)
                    if _deployment_rpm is None:
                        _deployment_rpm = _deployment.get("litellm_params", {}).get(
                            "rpm", None
                        )
                    if _deployment_rpm is None:
                        _deployment_rpm = _deployment.get("model_info", {}).get(
                            "rpm", None
                        )
                    if _deployment_rpm is None:
                        _deployment_rpm = float("inf")

                    ### GET CURRENT RPM ###
                    current_rpm = rpm_values[index] if rpm_values else 0

                    deployment_dict[id] = {
                        "current_tpm": current_tpm,
                        "tpm_limit": _deployment_tpm,
                        "current_rpm": current_rpm,
                        "rpm_limit": _deployment_rpm,
                    }
            raise litellm.RateLimitError(
                message=f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}",
                llm_provider="",
                model=model_group,
                response=httpx.Response(
                    status_code=429,
                    content="",
                    headers={"retry-after": str(60)},  # type: ignore
                    request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )

    def get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        parent_otel_span: Optional[Span] = None,
    ):
        """
        Returns a deployment with the lowest TPM/RPM usage.
        """
        # get list of potential deployments
        verbose_router_logger.debug(
            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
        )

        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        tpm_keys = []
        rpm_keys = []
        for m in healthy_deployments:
            if isinstance(m, dict):
                id = m.get("model_info", {}).get(
                    "id"
                )  # a deployment should always have an 'id'. this is set in router.py
                deployment_name = m.get("litellm_params", {}).get("model")
                tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
                rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)

                tpm_keys.append(tpm_key)
                rpm_keys.append(rpm_key)

        tpm_values = self.router_cache.batch_get_cache(
            keys=tpm_keys, parent_otel_span=parent_otel_span
        )  # [1, 2, None, ..]
        rpm_values = self.router_cache.batch_get_cache(
            keys=rpm_keys, parent_otel_span=parent_otel_span
        )  # [1, 2, None, ..]

        deployment = self._common_checks_available_deployment(
            model_group=model_group,
            healthy_deployments=healthy_deployments,
            tpm_keys=tpm_keys,
            tpm_values=tpm_values,
            rpm_keys=rpm_keys,
            rpm_values=rpm_values,
            messages=messages,
            input=input,
        )

        try:
            assert deployment is not None
            return deployment
        except Exception:
            ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
            deployment_dict = {}
            for index, _deployment in enumerate(healthy_deployments):
                if isinstance(_deployment, dict):
                    id = _deployment.get("model_info", {}).get("id")
                    ### GET DEPLOYMENT TPM LIMIT ###
                    _deployment_tpm = None
                    if _deployment_tpm is None:
                        _deployment_tpm = _deployment.get("tpm", None)
                    if _deployment_tpm is None:
                        _deployment_tpm = _deployment.get("litellm_params", {}).get(
                            "tpm", None
                        )
                    if _deployment_tpm is None:
                        _deployment_tpm = _deployment.get("model_info", {}).get(
                            "tpm", None
                        )
                    if _deployment_tpm is None:
                        _deployment_tpm = float("inf")

                    ### GET CURRENT TPM ###
                    current_tpm = tpm_values[index] if tpm_values else 0

                    ### GET DEPLOYMENT TPM LIMIT ###
                    _deployment_rpm = None
                    if _deployment_rpm is None:
                        _deployment_rpm = _deployment.get("rpm", None)
                    if _deployment_rpm is None:
                        _deployment_rpm = _deployment.get("litellm_params", {}).get(
                            "rpm", None
                        )
                    if _deployment_rpm is None:
                        _deployment_rpm = _deployment.get("model_info", {}).get(
                            "rpm", None
                        )
                    if _deployment_rpm is None:
                        _deployment_rpm = float("inf")

                    ### GET CURRENT RPM ###
                    current_rpm = rpm_values[index] if rpm_values else 0

                    deployment_dict[id] = {
                        "current_tpm": current_tpm,
                        "tpm_limit": _deployment_tpm,
                        "current_rpm": current_rpm,
                        "rpm_limit": _deployment_rpm,
                    }
            raise ValueError(
                f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
            )
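A sketch of wiring the v2 strategy into a multi-instance setup, assuming the usual Router redis_* parameters and the "usage-based-routing-v2" strategy name; hostnames and credentials below are placeholders, not values from this commit.

import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "azure/gpt-4o", "rpm": 100, "tpm": 100_000},
        },
    ],
    routing_strategy="usage-based-routing-v2",
    # v2 is meant to work across instances, so tpm/rpm counters live in a shared Redis
    redis_host=os.getenv("REDIS_HOST", "localhost"),
    redis_port=int(os.getenv("REDIS_PORT", "6379")),
    redis_password=os.getenv("REDIS_PASSWORD"),
)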
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Returns a random deployment from the list of healthy deployments.
|
||||
|
||||
If weights are provided, it will return a deployment based on the weights.
|
||||
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
def simple_shuffle(
|
||||
llm_router_instance: LitellmRouter,
|
||||
healthy_deployments: Union[List[Any], Dict[Any, Any]],
|
||||
model: str,
|
||||
) -> Dict:
|
||||
"""
|
||||
Returns a random deployment from the list of healthy deployments.
|
||||
|
||||
If weights are provided, it will return a deployment based on the weights.
|
||||
|
||||
If users pass `rpm` or `tpm`, we do a random weighted pick - based on `rpm`/`tpm`.
|
||||
|
||||
Args:
|
||||
llm_router_instance: LitellmRouter instance
|
||||
healthy_deployments: List of healthy deployments
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Dict: A single healthy deployment
|
||||
"""
|
||||
|
||||
############## Check if 'weight' param set for a weighted pick #################
|
||||
weight = healthy_deployments[0].get("litellm_params").get("weight", None)
|
||||
if weight is not None:
|
||||
# use weight-random pick if rpms provided
|
||||
weights = [m["litellm_params"].get("weight", 0) for m in healthy_deployments]
|
||||
verbose_router_logger.debug(f"\nweight {weights}")
|
||||
total_weight = sum(weights)
|
||||
weights = [weight / total_weight for weight in weights]
|
||||
verbose_router_logger.debug(f"\n weights {weights}")
|
||||
# Perform weighted random pick
|
||||
selected_index = random.choices(range(len(weights)), weights=weights)[0]
|
||||
verbose_router_logger.debug(f"\n selected index, {selected_index}")
|
||||
deployment = healthy_deployments[selected_index]
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
|
||||
)
|
||||
return deployment or deployment[0]
|
||||
############## Check if we can do a RPM/TPM based weighted pick #################
|
||||
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
|
||||
if rpm is not None:
|
||||
# use weight-random pick if rpms provided
|
||||
rpms = [m["litellm_params"].get("rpm", 0) for m in healthy_deployments]
|
||||
verbose_router_logger.debug(f"\nrpms {rpms}")
|
||||
total_rpm = sum(rpms)
|
||||
weights = [rpm / total_rpm for rpm in rpms]
|
||||
verbose_router_logger.debug(f"\n weights {weights}")
|
||||
# Perform weighted random pick
|
||||
selected_index = random.choices(range(len(rpms)), weights=weights)[0]
|
||||
verbose_router_logger.debug(f"\n selected index, {selected_index}")
|
||||
deployment = healthy_deployments[selected_index]
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
|
||||
)
|
||||
return deployment or deployment[0]
|
||||
    ############## Check if we can do a TPM based weighted pick #################
    tpm = healthy_deployments[0].get("litellm_params").get("tpm", None)
    if tpm is not None:
        # use weighted-random pick when tpms are provided
        tpms = [m["litellm_params"].get("tpm", 0) for m in healthy_deployments]
        verbose_router_logger.debug(f"\ntpms {tpms}")
        total_tpm = sum(tpms)
        weights = [t / total_tpm for t in tpms]
        verbose_router_logger.debug(f"\nweights {weights}")
        # Perform weighted random pick
        selected_index = random.choices(range(len(tpms)), weights=weights)[0]
        verbose_router_logger.debug(f"\nselected index {selected_index}")
        deployment = healthy_deployments[selected_index]
        verbose_router_logger.info(
            f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]}"
        )
        return deployment or deployment[0]
    ############## No RPM/TPM passed, we do a random pick #################
    item = random.choice(healthy_deployments)
    return item or item[0]
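A minimal, standalone sketch of the weighted pick used above, assuming a hypothetical `example_deployments` list shaped like the router's `healthy_deployments`; the model names and weights are illustrative only, and `random.choices` normalizes the weights internally, so explicit division by the total is optional.

import random  # already imported by this module; repeated here so the sketch is self-contained

example_deployments = [
    {"model_name": "gpt-4o", "litellm_params": {"model": "azure/gpt-4o", "weight": 3}},
    {"model_name": "gpt-4o", "litellm_params": {"model": "openai/gpt-4o", "weight": 1}},
]

# collect the per-deployment weights (defaulting to 0 when unset)
weights = [d["litellm_params"].get("weight", 0) for d in example_deployments]

# weighted random pick over the deployment indices; with the weights above
# the azure deployment is selected roughly 75% of the time
selected_index = random.choices(range(len(weights)), weights=weights)[0]
selected_deployment = example_deployments[selected_index]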
@@ -0,0 +1,146 @@
"""
Use this to route requests between Teams.

- If the tags in the request are a subset of the tags on a deployment, return that deployment
- If deployments are set with default tags, return all default deployments
- If no default deployments are set, return all deployments
"""

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from litellm._logging import verbose_logger
from litellm.types.router import RouterErrors

if TYPE_CHECKING:
    from litellm.router import Router as _Router

    LitellmRouter = _Router
else:
    LitellmRouter = Any

def is_valid_deployment_tag(
    deployment_tags: List[str], request_tags: List[str]
) -> bool:
    """
    Check whether a deployment's tags match the request tags.

    Returns True if any request tag is present in the deployment tags,
    or if the deployment is tagged "default".
    """
    if any(tag in deployment_tags for tag in request_tags):
        verbose_logger.debug(
            "adding deployment with tags: %s, request tags: %s",
            deployment_tags,
            request_tags,
        )
        return True
    elif "default" in deployment_tags:
        verbose_logger.debug(
            "adding default deployment with tags: %s, request tags: %s",
            deployment_tags,
            request_tags,
        )
        return True
    return False
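A quick sanity-check sketch of the matcher above; the tag values are hypothetical, and the expected results follow directly from the `any(...)` overlap check and the `"default"` fallback.

# assumes is_valid_deployment_tag from above is in scope
assert is_valid_deployment_tag(["free", "eu-only"], ["free"]) is True  # direct tag overlap
assert is_valid_deployment_tag(["default"], ["paid"]) is True          # "default" matches any tagged request
assert is_valid_deployment_tag(["paid"], ["free"]) is False            # no overlap, not default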
async def get_deployments_for_tag(
    llm_router_instance: LitellmRouter,
    model: str,  # used to raise the correct error
    healthy_deployments: Union[List[Any], Dict[Any, Any]],
    request_kwargs: Optional[Dict[Any, Any]] = None,
):
    """
    Returns the deployments that match the requested model and the tags in the request.

    Performs tag-based filtering using the tags in the request metadata and the tags set on each deployment.
    """
    if llm_router_instance.enable_tag_filtering is not True:
        return healthy_deployments

    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tag: request_kwargs is None, returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    if healthy_deployments is None:
        verbose_logger.debug(
            "get_deployments_for_tag: healthy_deployments is None, returning healthy_deployments"
        )
        return healthy_deployments

    verbose_logger.debug("request metadata: %s", request_kwargs.get("metadata"))
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        request_tags = metadata.get("tags")

        new_healthy_deployments = []
        if request_tags:
            verbose_logger.debug(
                "get_deployments_for_tag routing: request_tags: %s", request_tags
            )
            # e.g. request_tags=["free", "custom"]
            # keep all deployments whose tags match the request tags (or are tagged "default")
            for deployment in healthy_deployments:
                deployment_litellm_params = deployment.get("litellm_params")
                deployment_tags = deployment_litellm_params.get("tags")

                verbose_logger.debug(
                    "deployment: %s, deployment_tags: %s",
                    deployment,
                    deployment_tags,
                )

                if deployment_tags is None:
                    continue

                if is_valid_deployment_tag(deployment_tags, request_tags):
                    new_healthy_deployments.append(deployment)

            if len(new_healthy_deployments) == 0:
                raise ValueError(
                    f"{RouterErrors.no_deployments_with_tag_routing.value}. Passed model={model} and tags={request_tags}"
                )

            return new_healthy_deployments

        # for untagged requests, use default deployments if set
        _default_deployments_with_tags = []
        for deployment in healthy_deployments:
            if "default" in deployment.get("litellm_params", {}).get("tags", []):
                _default_deployments_with_tags.append(deployment)

        if len(_default_deployments_with_tags) > 0:
            return _default_deployments_with_tags

    # if no default deployments are found, return all healthy_deployments
    verbose_logger.debug(
        "no tags found in request metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
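A sketch of the shapes the filter above operates on, using hypothetical deployments and request metadata; it reuses `is_valid_deployment_tag` from above rather than constructing a full `Router`, so it only illustrates the per-deployment check, not the default-tag fallback.

example_healthy_deployments = [
    {"model_name": "gpt-4o", "litellm_params": {"model": "openai/gpt-4o", "tags": ["free"]}},
    {"model_name": "gpt-4o", "litellm_params": {"model": "azure/gpt-4o", "tags": ["paid"]}},
]
example_request_kwargs = {"metadata": {"tags": ["free"]}}

request_tags = example_request_kwargs["metadata"]["tags"]
filtered = [
    d
    for d in example_healthy_deployments
    if d["litellm_params"].get("tags") is not None
    and is_valid_deployment_tag(d["litellm_params"]["tags"], request_tags)
]
# filtered -> only the "free"-tagged deployment remains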
def _get_tags_from_request_kwargs(
    request_kwargs: Optional[Dict[Any, Any]] = None
) -> List[str]:
    """
    Helper to get tags from request kwargs

    Args:
        request_kwargs: The request kwargs to get tags from

    Returns:
        List[str]: The tags from the request kwargs
    """
    if request_kwargs is None:
        return []
    if "metadata" in request_kwargs:
        metadata = request_kwargs["metadata"]
        return metadata.get("tags", [])
    elif "litellm_params" in request_kwargs:
        litellm_params = request_kwargs["litellm_params"]
        _metadata = litellm_params.get("metadata", {})
        return _metadata.get("tags", [])
    return []
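A short sketch of the two request-kwargs shapes the helper above understands; the payloads are hypothetical.

assert _get_tags_from_request_kwargs({"metadata": {"tags": ["free"]}}) == ["free"]
assert _get_tags_from_request_kwargs(
    {"litellm_params": {"metadata": {"tags": ["paid"]}}}
) == ["paid"]
assert _get_tags_from_request_kwargs(None) == []
assert _get_tags_from_request_kwargs({}) == []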