structure saas with tools
@@ -0,0 +1,36 @@
from typing import Literal, Union

from . import *
from .cache_control_check import _PROXY_CacheControlCheck
from .managed_files import _PROXY_LiteLLMManagedFiles
from .max_budget_limiter import _PROXY_MaxBudgetLimiter
from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler

# List of all available hooks that can be enabled
PROXY_HOOKS = {
    "max_budget_limiter": _PROXY_MaxBudgetLimiter,
    "managed_files": _PROXY_LiteLLMManagedFiles,
    "parallel_request_limiter": _PROXY_MaxParallelRequestsHandler,
    "cache_control_check": _PROXY_CacheControlCheck,
}


def get_proxy_hook(
    hook_name: Union[
        Literal[
            "max_budget_limiter",
            "managed_files",
            "parallel_request_limiter",
            "cache_control_check",
        ],
        str,
    ]
):
    """
    Factory method to get a proxy hook instance by name
    """
    if hook_name not in PROXY_HOOKS:
        raise ValueError(
            f"Unknown hook: {hook_name}. Available hooks: {list(PROXY_HOOKS.keys())}"
        )
    return PROXY_HOOKS[hook_name]
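
For reference, a minimal usage sketch of the registry and factory above (illustrative, not part of the commit; the import path litellm.proxy.hooks is inferred from the relative imports and may differ). Note that get_proxy_hook returns the hook class rather than an instance, so callers instantiate it themselves, and some hooks, such as managed_files, need constructor arguments.

# Minimal usage sketch (illustrative, not part of the commit).
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook  # assumed import path

hook_cls = get_proxy_hook("cache_control_check")  # returns the class, not an instance
hook = hook_cls()                                 # _PROXY_CacheControlCheck takes no constructor args

try:
    get_proxy_hook("unknown_hook")
except ValueError as err:
    print(err)  # Unknown hook: unknown_hook. Available hooks: ['max_budget_limiter', ...]
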
Binary file not shown. (repeated for 12 binary files in this commit)
@@ -0,0 +1,156 @@
import traceback
from typing import Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes

    def __init__(self, endpoint, api_key, thresholds=None):
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            )
        self.endpoint = endpoint
        self.api_key = api_key
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)

        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }

        if thresholds is None:
            return default_thresholds

        for key, default in default_thresholds.items():
            if key not in thresholds:
                thresholds[key] = default

        return thresholds

    def _compute_result(self, response):
        result = {}

        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is not None:
                result[category] = {
                    "filtered": severity >= self.thresholds[category],
                    "severity": severity,
                }

        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        # Construct a request
        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        # Analyze text
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            verbose_proxy_logger.debug(traceback.format_exc())
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")

        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if isinstance(response, litellm.ModelResponse) and isinstance(
            response.choices[0], litellm.utils.Choices
        ):
            await self.test_violation(
                content=response.choices[0].message.content or "", source="output"
            )

    # async def async_post_call_streaming_hook(
    #     self,
    #     user_api_key_dict: UserAPIKeyAuth,
    #     response: str,
    # ):
    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #     await self.test_violation(content=response, source="output")
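
The filtering decision in _compute_result is a per-category severity-versus-threshold comparison. A small self-contained sketch of that comparison (illustrative only; the real code keys the dicts by azure.ai.contentsafety TextCategory members, and the severity values here are made up):

# Standalone sketch of the threshold check in _compute_result (illustrative only).
thresholds = {"Hate": 4, "SelfHarm": 4, "Sexual": 4, "Violence": 4}  # same defaults as above
severities = {"Hate": 2, "Violence": 6}                              # hypothetical API response

result = {
    category: {"filtered": severity >= thresholds[category], "severity": severity}
    for category, severity in severities.items()
}
print(result)
# {'Hate': {'filtered': False, 'severity': 2}, 'Violence': {'filtered': True, 'severity': 6}}
# Any "filtered": True entry causes test_violation to raise HTTPException(status_code=400).
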
@@ -0,0 +1,149 @@
# What this does?
## Gets a key's redis cache, and store it in memory for 1 minute.
## This reduces the number of REDIS GET requests made during high-traffic by the proxy.
### [BETA] this is in Beta. And might change.

import traceback
from typing import Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_BatchRedisRequests(CustomLogger):
    # Class variables or attributes
    in_memory_cache: Optional[InMemoryCache] = None

    def __init__(self):
        if litellm.cache is not None:
            litellm.cache.async_get_cache = (
                self.async_get_cache
            )  # map the litellm 'get_cache' function to our custom function

    def print_verbose(
        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
    ):
        if debug_level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        elif debug_level == "INFO":
            verbose_proxy_logger.debug(print_statement)
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            """
            Get the user key

            Check if a key starting with `litellm:<api_key>:<call_type:` exists in-memory

            If no, then get relevant cache from redis
            """
            api_key = user_api_key_dict.api_key

            cache_key_name = f"litellm:{api_key}:{call_type}"
            self.in_memory_cache = cache.in_memory_cache

            key_value_dict = {}
            in_memory_cache_exists = False
            for key in cache.in_memory_cache.cache_dict.keys():
                if isinstance(key, str) and key.startswith(cache_key_name):
                    in_memory_cache_exists = True

            if in_memory_cache_exists is False and litellm.cache is not None:
                """
                - Check if `litellm.Cache` is redis
                - Get the relevant values
                """
                if litellm.cache.type is not None and isinstance(
                    litellm.cache.cache, RedisCache
                ):
                    # Initialize an empty list to store the keys
                    keys = []
                    self.print_verbose(f"cache_key_name: {cache_key_name}")
                    # Use the SCAN iterator to fetch keys matching the pattern
                    keys = await litellm.cache.cache.async_scan_iter(
                        pattern=cache_key_name, count=100
                    )
                    # If you need the truly "last" based on time or another criteria,
                    # ensure your key naming or storage strategy allows this determination
                    # Here you would sort or filter the keys as needed based on your strategy
                    self.print_verbose(f"redis keys: {keys}")
                    if len(keys) > 0:
                        key_value_dict = (
                            await litellm.cache.cache.async_batch_get_cache(
                                key_list=keys
                            )
                        )

                    ## Add to cache
                    if len(key_value_dict.items()) > 0:
                        await cache.in_memory_cache.async_set_cache_pipeline(
                            cache_list=list(key_value_dict.items()), ttl=60
                        )
                    ## Set cache namespace if it's a miss
                    data["metadata"]["redis_namespace"] = cache_key_name
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_get_cache(self, *args, **kwargs):
        """
        - Check if the cache key is in-memory

        - Else:
            - add missing cache key from REDIS
            - update in-memory cache
            - return redis cache request
        """
        try:  # never block execution
            cache_key: Optional[str] = None
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            elif litellm.cache is not None:
                cache_key = litellm.cache.get_cache_key(
                    *args, **kwargs
                )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic

            if (
                cache_key is not None
                and self.in_memory_cache is not None
                and litellm.cache is not None
            ):
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = self.in_memory_cache.get_cache(
                    cache_key, *args, **kwargs
                )
                if cached_result is None:
                    cached_result = await litellm.cache.cache.async_get_cache(
                        cache_key, *args, **kwargs
                    )
                    if cached_result is not None:
                        await self.in_memory_cache.async_set_cache(
                            cache_key, cached_result, ttl=60
                        )
                return litellm.cache._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            return None
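
Conceptually, async_get_cache above is a two-tier lookup: try the in-memory cache first, fall back to redis, and backfill the in-memory cache with a 60-second TTL so the next request skips the redis round trip. A stand-alone sketch of that flow with plain dicts standing in for InMemoryCache and RedisCache (the key and payload are hypothetical):

# Conceptual sketch of the two-tier lookup (illustrative only; these dicts are stand-ins).
in_memory = {}                                                   # entries live ~60s in the real hook
redis_store = {"litellm:sk-123:completion:abc": {"response": "cached completion"}}

def get_cached(cache_key):
    hit = in_memory.get(cache_key)        # 1. check in-memory first
    if hit is None:
        hit = redis_store.get(cache_key)  # 2. fall back to redis
        if hit is not None:
            in_memory[cache_key] = hit    # 3. backfill in-memory for subsequent requests
    return hit

print(get_cached("litellm:sk-123:completion:abc"))    # served from redis, then cached in memory
print("litellm:sk-123:completion:abc" in in_memory)   # True
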
@@ -0,0 +1,58 @@
# What this does?
## Checks if key is allowed to use the cache controls passed in to the completion() call


from fastapi import HTTPException

from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_CacheControlCheck(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        pass

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            verbose_proxy_logger.debug("Inside Cache Control Check Pre-Call Hook")
            allowed_cache_controls = user_api_key_dict.allowed_cache_controls

            if data.get("cache", None) is None:
                return

            cache_args = data.get("cache", None)
            if isinstance(cache_args, dict):
                for k, v in cache_args.items():
                    if (
                        (allowed_cache_controls is not None)
                        and (isinstance(allowed_cache_controls, list))
                        and (
                            len(allowed_cache_controls) > 0
                        )  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
                        and k not in allowed_cache_controls
                    ):
                        raise HTTPException(
                            status_code=403,
                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
                        )
            else:  # invalid cache
                return

        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
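
The rule enforced above: a non-empty allowed_cache_controls list on the key acts as an allowlist for the fields in the request's cache dict, while an empty list means no restriction. A small sketch of just that predicate (illustrative only):

# Sketch of the permission rule above (illustrative only).
def is_cache_control_allowed(allowed_cache_controls, requested_cache_args: dict) -> bool:
    for k in requested_cache_args:
        if (
            isinstance(allowed_cache_controls, list)
            and len(allowed_cache_controls) > 0  # empty list is treated as "no restriction"
            and k not in allowed_cache_controls
        ):
            return False  # the hook raises HTTPException(status_code=403) at this point
    return True

print(is_cache_control_allowed([], {"no-cache": True}))            # True  - empty list is nullable
print(is_cache_control_allowed(["ttl"], {"no-cache": True}))       # False - "no-cache" not in allowlist
print(is_cache_control_allowed(["ttl", "no-cache"], {"ttl": 60}))  # True
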
@@ -0,0 +1,306 @@
# What is this?
## Allocates dynamic tpm/rpm quota for a project based on current traffic
## Tracks num active projects per minute

import asyncio
import os
from typing import List, Literal, Optional, Tuple, Union

from fastapi import HTTPException

import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.router import ModelGroupInfo
from litellm.utils import get_utc_datetime


class DynamicRateLimiterCache:
    """
    Thin wrapper on DualCache for this file.

    Track number of active projects calling a model.
    """

    def __init__(self, cache: DualCache) -> None:
        self.cache = cache
        self.ttl = 60  # 1 min ttl

    async def async_get_cache(self, model: str) -> Optional[int]:
        dt = get_utc_datetime()
        current_minute = dt.strftime("%H-%M")
        key_name = "{}:{}".format(current_minute, model)
        _response = await self.cache.async_get_cache(key=key_name)
        response: Optional[int] = None
        if _response is not None:
            response = len(_response)
        return response

    async def async_set_cache_sadd(self, model: str, value: List):
        """
        Add value to set.

        Parameters:
        - model: str, the name of the model group
        - value: str, the team id

        Returns:
        - None

        Raises:
        - Exception, if unable to connect to cache client (if redis caching enabled)
        """
        try:
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")

            key_name = "{}:{}".format(current_minute, model)
            await self.cache.async_set_cache_sadd(
                key=key_name, value=value, ttl=self.ttl
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}".format(
                    str(e)
                )
            )
            raise e


class _PROXY_DynamicRateLimitHandler(CustomLogger):
    # Class variables or attributes
    def __init__(self, internal_usage_cache: DualCache):
        self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)

    def update_variables(self, llm_router: Router):
        self.llm_router = llm_router

    async def check_available_usage(
        self, model: str, priority: Optional[str] = None
    ) -> Tuple[
        Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
    ]:
        """
        For a given model, get its available tpm

        Params:
        - model: str, the name of the model in the router model_list
        - priority: Optional[str], the priority for the request.

        Returns
        - Tuple[available_tpm, available_rpm, remaining_model_tpm, remaining_model_rpm, active_projects]
          - available_tpm: int or null - always 0 or positive.
          - available_rpm: int or null - always 0 or positive.
          - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
          - remaining_model_rpm: int or null. If available rpm is int, then this will be too.
          - active_projects: int or null
        """
        try:
            weight: float = 1
            if (
                litellm.priority_reservation is None
                or priority not in litellm.priority_reservation
            ):
                verbose_proxy_logger.error(
                    "Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format(
                        priority, litellm.priority_reservation
                    )
                )
            elif priority is not None and litellm.priority_reservation is not None:
                if os.getenv("LITELLM_LICENSE", None) is None:
                    verbose_proxy_logger.error(
                        "PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
                    )
                else:
                    weight = litellm.priority_reservation[priority]

            active_projects = await self.internal_usage_cache.async_get_cache(
                model=model
            )
            (
                current_model_tpm,
                current_model_rpm,
            ) = await self.llm_router.get_model_group_usage(model_group=model)
            model_group_info: Optional[
                ModelGroupInfo
            ] = self.llm_router.get_model_group_info(model_group=model)
            total_model_tpm: Optional[int] = None
            total_model_rpm: Optional[int] = None
            if model_group_info is not None:
                if model_group_info.tpm is not None:
                    total_model_tpm = model_group_info.tpm
                if model_group_info.rpm is not None:
                    total_model_rpm = model_group_info.rpm

            remaining_model_tpm: Optional[int] = None
            if total_model_tpm is not None and current_model_tpm is not None:
                remaining_model_tpm = total_model_tpm - current_model_tpm
            elif total_model_tpm is not None:
                remaining_model_tpm = total_model_tpm

            remaining_model_rpm: Optional[int] = None
            if total_model_rpm is not None and current_model_rpm is not None:
                remaining_model_rpm = total_model_rpm - current_model_rpm
            elif total_model_rpm is not None:
                remaining_model_rpm = total_model_rpm

            available_tpm: Optional[int] = None

            if remaining_model_tpm is not None:
                if active_projects is not None:
                    available_tpm = int(remaining_model_tpm * weight / active_projects)
                else:
                    available_tpm = int(remaining_model_tpm * weight)

            if available_tpm is not None and available_tpm < 0:
                available_tpm = 0

            available_rpm: Optional[int] = None

            if remaining_model_rpm is not None:
                if active_projects is not None:
                    available_rpm = int(remaining_model_rpm * weight / active_projects)
                else:
                    available_rpm = int(remaining_model_rpm * weight)

            if available_rpm is not None and available_rpm < 0:
                available_rpm = 0
            return (
                available_tpm,
                available_rpm,
                remaining_model_tpm,
                remaining_model_rpm,
                active_projects,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format(
                    str(e)
                )
            )
            return None, None, None, None, None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[
        Union[Exception, str, dict]
    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
        """
        - For a model group
        - Check if tpm/rpm available
        - Raise RateLimitError if no tpm/rpm available
        """
        if "model" in data:
            key_priority: Optional[str] = user_api_key_dict.metadata.get(
                "priority", None
            )
            (
                available_tpm,
                available_rpm,
                model_tpm,
                model_rpm,
                active_projects,
            ) = await self.check_available_usage(
                model=data["model"], priority=key_priority
            )
            ### CHECK TPM ###
            if available_tpm is not None and available_tpm == 0:
                raise HTTPException(
                    status_code=429,
                    detail={
                        "error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
                            user_api_key_dict.api_key,
                            available_tpm,
                            model_tpm,
                            active_projects,
                        )
                    },
                )
            ### CHECK RPM ###
            elif available_rpm is not None and available_rpm == 0:
                raise HTTPException(
                    status_code=429,
                    detail={
                        "error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
                            user_api_key_dict.api_key,
                            available_rpm,
                            model_rpm,
                            active_projects,
                        )
                    },
                )
            elif available_rpm is not None or available_tpm is not None:
                ## UPDATE CACHE WITH ACTIVE PROJECT
                asyncio.create_task(
                    self.internal_usage_cache.async_set_cache_sadd(  # this is a set
                        model=data["model"],  # type: ignore
                        value=[user_api_key_dict.token or "default_key"],
                    )
                )
        return None

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
        try:
            if isinstance(response, ModelResponse):
                model_info = self.llm_router.get_model_info(
                    id=response._hidden_params["model_id"]
                )
                assert (
                    model_info is not None
                ), "Model info for model with id={} is None".format(
                    response._hidden_params["model_id"]
                )
                key_priority: Optional[str] = user_api_key_dict.metadata.get(
                    "priority", None
                )
                (
                    available_tpm,
                    available_rpm,
                    model_tpm,
                    model_rpm,
                    active_projects,
                ) = await self.check_available_usage(
                    model=model_info["model_name"], priority=key_priority
                )
                response._hidden_params[
                    "additional_headers"
                ] = {  # Add additional response headers - easier debugging
                    "x-litellm-model_group": model_info["model_name"],
                    "x-ratelimit-remaining-litellm-project-tokens": available_tpm,
                    "x-ratelimit-remaining-litellm-project-requests": available_rpm,
                    "x-ratelimit-remaining-model-tokens": model_tpm,
                    "x-ratelimit-remaining-model-requests": model_rpm,
                    "x-ratelimit-current-active-projects": active_projects,
                }

                return response
            return await super().async_post_call_success_hook(
                data=data,
                user_api_key_dict=user_api_key_dict,
                response=response,
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            return response
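
The quota split in check_available_usage is remaining capacity, scaled by the priority weight, divided evenly across the projects active in the current minute. A worked example with hypothetical numbers:

# Worked example of the quota split in check_available_usage (hypothetical numbers).
total_model_tpm = 100_000   # from the router's ModelGroupInfo
current_model_tpm = 40_000  # usage already recorded for this minute
active_projects = 3         # distinct keys seen for this model this minute
weight = 0.5                # litellm.priority_reservation[priority]; defaults to 1

remaining_model_tpm = total_model_tpm - current_model_tpm          # 60_000
available_tpm = int(remaining_model_tpm * weight / active_projects)
print(available_tpm)        # 10000 tokens this project may still use this minute
# available_tpm == 0 is what triggers the HTTP 429 in async_pre_call_hook.
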
@@ -0,0 +1,28 @@
[
    {
        "name": "Zip code Recognizer",
        "supported_language": "en",
        "patterns": [
            {
                "name": "zip code (weak)",
                "regex": "(\\b\\d{5}(?:\\-\\d{4})?\\b)",
                "score": 0.01
            }
        ],
        "context": ["zip", "code"],
        "supported_entity": "ZIP"
    },
    {
        "name": "Swiss AHV Number Recognizer",
        "supported_language": "en",
        "patterns": [
            {
                "name": "AHV number (strong)",
                "regex": "(756\\.\\d{4}\\.\\d{4}\\.\\d{2})|(756\\d{10})",
                "score": 0.95
            }
        ],
        "context": ["AHV", "social security", "Swiss"],
        "supported_entity": "AHV_NUMBER"
    }
]
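
These look like Presidio-style custom pattern recognizer definitions for PII detection. As a quick sanity check of what the two regexes match, here is a sketch using Python's re module with the JSON-unescaped patterns (the sample inputs are made up):

# Regex sanity check for the two recognizers above (sample inputs are made up).
import re

zip_pattern = r"(\b\d{5}(?:\-\d{4})?\b)"                 # "zip code (weak)", score 0.01
ahv_pattern = r"(756\.\d{4}\.\d{4}\.\d{2})|(756\d{10})"  # "AHV number (strong)", score 0.95

print(bool(re.search(zip_pattern, "Ship to 94105-1234")))       # True
print(bool(re.search(ahv_pattern, "AHV: 756.1234.5678.97")))    # True (formatted)
print(bool(re.search(ahv_pattern, "7561234567897")))            # True (756 + 10 digits, unformatted)
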
@@ -0,0 +1,323 @@
import asyncio
import json
import uuid
from datetime import datetime, timezone
from typing import Any, List, Optional

from fastapi import status

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
    GenerateKeyRequest,
    GenerateKeyResponse,
    KeyRequest,
    LiteLLM_AuditLogs,
    LiteLLM_VerificationToken,
    LitellmTableNames,
    ProxyErrorTypes,
    ProxyException,
    RegenerateKeyRequest,
    UpdateKeyRequest,
    UserAPIKeyAuth,
    WebhookEvent,
)

# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"


class KeyManagementEventHooks:
    @staticmethod
    async def async_key_generated_hook(
        data: GenerateKeyRequest,
        response: GenerateKeyResponse,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Hook that runs after a successful /key/generate request

        Handles the following:
        - Sending Email with Key Details
        - Storing Audit Logs for key generation
        - Storing Generated Key in DB
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        if data.send_invite_email is True:
            await KeyManagementEventHooks._send_key_created_email(
                response.model_dump(exclude_none=True)
            )

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = response.model_dump_json(exclude_none=True)
            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=response.token_id or "",
                        action="created",
                        updated_values=_updated_values,
                        before_value=None,
                    )
                )
            )
        # store the generated key in the secret manager
        await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
            secret_name=data.key_alias or f"virtual-key-{response.token_id}",
            secret_token=response.key,
        )

    @staticmethod
    async def async_key_updated_hook(
        data: UpdateKeyRequest,
        existing_key_row: Any,
        response: Any,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/update processing hook

        Handles the following:
        - Storing Audit Logs for key update
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = json.dumps(data.json(exclude_none=True), default=str)

            _before_value = existing_key_row.json(exclude_none=True)
            _before_value = json.dumps(_before_value, default=str)

            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=data.key,
                        action="updated",
                        updated_values=_updated_values,
                        before_value=_before_value,
                    )
                )
            )

    @staticmethod
    async def async_key_rotated_hook(
        data: Optional[RegenerateKeyRequest],
        existing_key_row: Any,
        response: GenerateKeyResponse,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        # store the generated key in the secret manager
        if data is not None and response.token_id is not None:
            initial_secret_name = (
                existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
            )
            await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
                current_secret_name=initial_secret_name,
                new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
                new_secret_value=response.key,
            )

    @staticmethod
    async def async_key_deleted_hook(
        data: KeyRequest,
        keys_being_deleted: List[LiteLLM_VerificationToken],
        response: dict,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/delete processing hook

        Handles the following:
        - Storing Audit Logs for key deletion
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
        if litellm.store_audit_logs is True and data.keys is not None:
            # make an audit log for each team deleted
            for key in data.keys:
                key_row = await prisma_client.get_data(  # type: ignore
                    token=key, table_name="key", query_type="find_unique"
                )

                if key_row is None:
                    raise ProxyException(
                        message=f"Key {key} not found",
                        type=ProxyErrorTypes.bad_request_error,
                        param="key",
                        code=status.HTTP_404_NOT_FOUND,
                    )

                key_row = key_row.json(exclude_none=True)
                _key_row = json.dumps(key_row, default=str)

                asyncio.create_task(
                    create_audit_log_for_update(
                        request_data=LiteLLM_AuditLogs(
                            id=str(uuid.uuid4()),
                            updated_at=datetime.now(timezone.utc),
                            changed_by=litellm_changed_by
                            or user_api_key_dict.user_id
                            or litellm_proxy_admin_name,
                            changed_by_api_key=user_api_key_dict.api_key,
                            table_name=LitellmTableNames.KEY_TABLE_NAME,
                            object_id=key,
                            action="deleted",
                            updated_values="{}",
                            before_value=_key_row,
                        )
                    )
                )
        # delete the keys from the secret manager
        await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
            keys_being_deleted=keys_being_deleted
        )
        pass

    @staticmethod
    async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
        """
        Store a virtual key in the secret manager

        Args:
            secret_name: Name of the virtual key
            secret_token: Value of the virtual key (example: sk-1234)
        """
        if litellm._key_management_settings is not None:
            if litellm._key_management_settings.store_virtual_keys is True:
                from litellm.secret_managers.base_secret_manager import (
                    BaseSecretManager,
                )

                # store the key in the secret manager
                if isinstance(litellm.secret_manager_client, BaseSecretManager):
                    await litellm.secret_manager_client.async_write_secret(
                        secret_name=KeyManagementEventHooks._get_secret_name(
                            secret_name
                        ),
                        secret_value=secret_token,
                    )

    @staticmethod
    async def _rotate_virtual_key_in_secret_manager(
        current_secret_name: str, new_secret_name: str, new_secret_value: str
    ):
        """
        Update a virtual key in the secret manager

        Args:
            secret_name: Name of the virtual key
            secret_token: Value of the virtual key (example: sk-1234)
        """
        if litellm._key_management_settings is not None:
            if litellm._key_management_settings.store_virtual_keys is True:
                from litellm.secret_managers.base_secret_manager import (
                    BaseSecretManager,
                )

                # store the key in the secret manager
                if isinstance(litellm.secret_manager_client, BaseSecretManager):
                    await litellm.secret_manager_client.async_rotate_secret(
                        current_secret_name=KeyManagementEventHooks._get_secret_name(
                            current_secret_name
                        ),
                        new_secret_name=KeyManagementEventHooks._get_secret_name(
                            new_secret_name
                        ),
                        new_secret_value=new_secret_value,
                    )

    @staticmethod
    def _get_secret_name(secret_name: str) -> str:
        if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
            "/"
        ):
            return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
        else:
            return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"

    @staticmethod
    async def _delete_virtual_keys_from_secret_manager(
        keys_being_deleted: List[LiteLLM_VerificationToken],
    ):
        """
        Deletes virtual keys from the secret manager

        Args:
            keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
        """
        if litellm._key_management_settings is not None:
            if litellm._key_management_settings.store_virtual_keys is True:
                from litellm.secret_managers.base_secret_manager import (
                    BaseSecretManager,
                )

                if isinstance(litellm.secret_manager_client, BaseSecretManager):
                    for key in keys_being_deleted:
                        if key.key_alias is not None:
                            await litellm.secret_manager_client.async_delete_secret(
                                secret_name=KeyManagementEventHooks._get_secret_name(
                                    key.key_alias
                                )
                            )
                        else:
                            verbose_proxy_logger.warning(
                                f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
                            )

    @staticmethod
    async def _send_key_created_email(response: dict):
        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

        if "email" not in general_settings.get("alerting", []):
            raise ValueError(
                "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
            )
        event = WebhookEvent(
            event="key_created",
            event_group="key",
            event_message="API Key Created",
            token=response.get("token", ""),
            spend=response.get("spend", 0.0),
            max_budget=response.get("max_budget", 0.0),
            user_id=response.get("user_id", None),
            team_id=response.get("team_id", "Default Team"),
            key_alias=response.get("key_alias", None),
        )

        # If user configured email alerting - send an Email letting their end-user know the key was created
        asyncio.create_task(
            proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                webhook_event=event,
            )
        )
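
The naming convention used by _get_secret_name is simply "prefix + / + secret name", guarding against a double slash when the configured prefix already ends with "/". A standalone sketch of that behavior with the default prefix shown above:

# Sketch of the prefix handling in _get_secret_name (illustrative only).
def get_secret_name(prefix: str, secret_name: str) -> str:
    # avoid a double slash when the configured prefix already ends with "/"
    if prefix.endswith("/"):
        return f"{prefix}{secret_name}"
    return f"{prefix}/{secret_name}"

print(get_secret_name("litellm/", "virtual-key-abc123"))  # litellm/virtual-key-abc123
print(get_secret_name("litellm", "virtual-key-abc123"))   # litellm/virtual-key-abc123
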
@@ -0,0 +1,429 @@
# What is this?
## This hook is used to check for LiteLLM managed files in the request body, and replace them with model-specific file id

import base64
import json
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast

from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.proxy._types import CallTypes, LiteLLM_ManagedFileTable, UserAPIKeyAuth
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionFileObject,
    CreateFileRequest,
    OpenAIFileObject,
    OpenAIFilesPurpose,
)
from litellm.types.utils import SpecialEnums

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.proxy.utils import PrismaClient as _PrismaClient

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    PrismaClient = _PrismaClient
else:
    Span = Any
    InternalUsageCache = Any
    PrismaClient = Any


class BaseFileEndpoints(ABC):
    @abstractmethod
    async def afile_retrieve(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
    ) -> OpenAIFileObject:
        pass

    @abstractmethod
    async def afile_list(
        self, custom_llm_provider: str, **data: dict
    ) -> List[OpenAIFileObject]:
        pass

    @abstractmethod
    async def afile_delete(
        self, custom_llm_provider: str, file_id: str, **data: dict
    ) -> OpenAIFileObject:
        pass


class _PROXY_LiteLLMManagedFiles(CustomLogger):
    # Class variables or attributes
    def __init__(
        self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
    ):
        self.internal_usage_cache = internal_usage_cache
        self.prisma_client = prisma_client

    async def store_unified_file_id(
        self,
        file_id: str,
        file_object: OpenAIFileObject,
        litellm_parent_otel_span: Optional[Span],
        model_mappings: Dict[str, str],
    ) -> None:
        verbose_logger.info(
            f"Storing LiteLLM Managed File object with id={file_id} in cache"
        )
        litellm_managed_file_object = LiteLLM_ManagedFileTable(
            unified_file_id=file_id,
            file_object=file_object,
            model_mappings=model_mappings,
        )
        await self.internal_usage_cache.async_set_cache(
            key=file_id,
            value=litellm_managed_file_object.model_dump(),
            litellm_parent_otel_span=litellm_parent_otel_span,
        )

        await self.prisma_client.db.litellm_managedfiletable.create(
            data={
                "unified_file_id": file_id,
                "file_object": file_object.model_dump_json(),
                "model_mappings": json.dumps(model_mappings),
            }
        )

    async def get_unified_file_id(
        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
    ) -> Optional[LiteLLM_ManagedFileTable]:
        ## CHECK CACHE
        result = cast(
            Optional[dict],
            await self.internal_usage_cache.async_get_cache(
                key=file_id,
                litellm_parent_otel_span=litellm_parent_otel_span,
            ),
        )

        if result:
            return LiteLLM_ManagedFileTable(**result)

        ## CHECK DB
        db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
            where={"unified_file_id": file_id}
        )

        if db_object:
            return LiteLLM_ManagedFileTable(**db_object.model_dump())
        return None

    async def delete_unified_file_id(
        self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
    ) -> OpenAIFileObject:
        ## get old value
        initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first(
            where={"unified_file_id": file_id}
        )
        if initial_value is None:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
        ## delete old value
        await self.internal_usage_cache.async_set_cache(
            key=file_id,
            value=None,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        await self.prisma_client.db.litellm_managedfiletable.delete(
            where={"unified_file_id": file_id}
        )
        return initial_value.file_object

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: Dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, Dict, None]:
        """
        - Detect litellm_proxy/ file_id
        - add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
        """
        if call_type == CallTypes.completion.value:
            messages = data.get("messages")
            if messages:
                file_ids = self.get_file_ids_from_messages(messages)
                if file_ids:
                    model_file_id_mapping = await self.get_model_file_id_mapping(
                        file_ids, user_api_key_dict.parent_otel_span
                    )

                    data["model_file_id_mapping"] = model_file_id_mapping

        return data

    def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
        """
        Gets file ids from messages
        """
        file_ids = []
        for message in messages:
            if message.get("role") == "user":
                content = message.get("content")
                if content:
                    if isinstance(content, str):
                        continue
                    for c in content:
                        if c["type"] == "file":
                            file_object = cast(ChatCompletionFileObject, c)
                            file_object_file_field = file_object["file"]
                            file_id = file_object_file_field.get("file_id")
                            if file_id:
                                file_ids.append(file_id)
        return file_ids

    @staticmethod
    def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
        is_base64_unified_file_id = (
            _PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
        )
        if is_base64_unified_file_id:
            return is_base64_unified_file_id
        else:
            return b64_uid

    @staticmethod
    def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
        # Add padding back if needed
        padded = b64_uid + "=" * (-len(b64_uid) % 4)
        # Decode from base64
        try:
            decoded = base64.urlsafe_b64decode(padded).decode()
            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
                return decoded
            else:
                return False
        except Exception:
            return False

    def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
        is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
        if is_base64_unified_file_id:
            return is_base64_unified_file_id
        else:
            return b64_uid

    async def get_model_file_id_mapping(
        self, file_ids: List[str], litellm_parent_otel_span: Span
    ) -> dict:
        """
        Get model-specific file IDs for a list of proxy file IDs.
        Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id

        1. Get all the litellm_proxy/ file_ids from the messages
        2. For each file_id, search for cache keys matching the pattern file_id:*
        3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id

        Example:
        {
            "litellm_proxy/file_id": {
                "model_id": "model_file_id"
            }
        }
        """

        file_id_mapping: Dict[str, Dict[str, str]] = {}
        litellm_managed_file_ids = []

        for file_id in file_ids:
            ## CHECK IF FILE ID IS MANAGED BY LITELM
            is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)

            if is_base64_unified_file_id:
                litellm_managed_file_ids.append(file_id)

        if litellm_managed_file_ids:
            # Get all cache keys matching the pattern file_id:*
            for file_id in litellm_managed_file_ids:
                # Search for any cache key starting with this file_id
                unified_file_object = await self.get_unified_file_id(
                    file_id, litellm_parent_otel_span
                )
                if unified_file_object:
                    file_id_mapping[file_id] = unified_file_object.model_mappings

        return file_id_mapping

    async def create_file_for_each_model(
        self,
        llm_router: Optional[Router],
        _create_file_request: CreateFileRequest,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
    ) -> List[OpenAIFileObject]:
        if llm_router is None:
            raise Exception("LLM Router not initialized. Ensure models added to proxy.")
        responses = []
        for model in target_model_names_list:
            individual_response = await llm_router.acreate_file(
                model=model, **_create_file_request
            )
            responses.append(individual_response)

        return responses

    async def acreate_file(
        self,
        create_file_request: CreateFileRequest,
        llm_router: Router,
        target_model_names_list: List[str],
        litellm_parent_otel_span: Span,
    ) -> OpenAIFileObject:
        responses = await self.create_file_for_each_model(
            llm_router=llm_router,
            _create_file_request=create_file_request,
            target_model_names_list=target_model_names_list,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )
        response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
            file_objects=responses,
            create_file_request=create_file_request,
            internal_usage_cache=self.internal_usage_cache,
            litellm_parent_otel_span=litellm_parent_otel_span,
        )

        ## STORE MODEL MAPPINGS IN DB
        model_mappings: Dict[str, str] = {}
        for file_object in responses:
            model_id = file_object._hidden_params.get("model_id")
            if model_id is None:
                verbose_logger.warning(
                    f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
                )
                continue
            file_id = file_object.id
            model_mappings[model_id] = file_id

        await self.store_unified_file_id(
            file_id=response.id,
            file_object=response,
            litellm_parent_otel_span=litellm_parent_otel_span,
            model_mappings=model_mappings,
        )
        return response

    @staticmethod
    async def return_unified_file_id(
        file_objects: List[OpenAIFileObject],
        create_file_request: CreateFileRequest,
        internal_usage_cache: InternalUsageCache,
        litellm_parent_otel_span: Span,
    ) -> OpenAIFileObject:
        ## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
        file_data = extract_file_data(create_file_request["file"])

        file_type = file_data["content_type"]

        unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
            file_type, str(uuid.uuid4())
        )

        # Convert to URL-safe base64 and strip padding
        base64_unified_file_id = (
            base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
        )

        ## CREATE RESPONSE OBJECT

        response = OpenAIFileObject(
            id=base64_unified_file_id,
            object="file",
            purpose=create_file_request["purpose"],
            created_at=file_objects[0].created_at,
            bytes=file_objects[0].bytes,
            filename=file_objects[0].filename,
            status="uploaded",
        )

        return response

    async def afile_retrieve(
        self, file_id: str, litellm_parent_otel_span: Optional[Span]
    ) -> OpenAIFileObject:
        stored_file_object = await self.get_unified_file_id(
            file_id, litellm_parent_otel_span
        )
        if stored_file_object:
            return stored_file_object.file_object
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

    async def afile_list(
        self,
        purpose: Optional[OpenAIFilesPurpose],
        litellm_parent_otel_span: Optional[Span],
        **data: Dict,
    ) -> List[OpenAIFileObject]:
        return []

    async def afile_delete(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> OpenAIFileObject:
        file_id = self.convert_b64_uid_to_unified_uid(file_id)
        model_file_id_mapping = await self.get_model_file_id_mapping(
            [file_id], litellm_parent_otel_span
        )
        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
        if specific_model_file_id_mapping:
            for model_id, file_id in specific_model_file_id_mapping.items():
                await llm_router.afile_delete(model=model_id, file_id=file_id, **data)  # type: ignore

        stored_file_object = await self.delete_unified_file_id(
            file_id, litellm_parent_otel_span
        )
        if stored_file_object:
            return stored_file_object
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

    async def afile_content(
        self,
        file_id: str,
        litellm_parent_otel_span: Optional[Span],
        llm_router: Router,
        **data: Dict,
    ) -> str:
        """
        Get the content of a file from first model that has it
        """
        model_file_id_mapping = await self.get_model_file_id_mapping(
            [file_id], litellm_parent_otel_span
        )
        specific_model_file_id_mapping = model_file_id_mapping.get(file_id)

        if specific_model_file_id_mapping:
            exception_dict = {}
            for model_id, file_id in specific_model_file_id_mapping.items():
                try:
                    return await llm_router.afile_content(model=model_id, file_id=file_id, **data)  # type: ignore
                except Exception as e:
                    exception_dict[model_id] = str(e)
            raise Exception(
                f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
            )
        else:
            raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
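
The unified file id used above is an internal string encoded as URL-safe base64 with the "=" padding stripped; _is_base64_encoded_unified_file_id re-pads, decodes, and checks the managed-file prefix. A round-trip sketch (the id format below is a placeholder; the real values come from SpecialEnums):

# Round-trip sketch of the unified file id encoding (illustrative; the id string is a placeholder).
import base64

unified_file_id = "litellm_proxy:application/pdf;unified_id,1234abcd"  # placeholder format

# encode: URL-safe base64 with padding stripped (as in return_unified_file_id)
b64_uid = base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")

# decode: re-pad, decode, then check the prefix (as in _is_base64_encoded_unified_file_id)
padded = b64_uid + "=" * (-len(b64_uid) % 4)
decoded = base64.urlsafe_b64decode(padded).decode()
print(decoded == unified_file_id)  # True
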
@@ -0,0 +1,49 @@
from fastapi import HTTPException

from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _PROXY_MaxBudgetLimiter(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        pass

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
            cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
            user_row = await cache.async_get_cache(
                cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
            )
            if user_row is None:  # value not yet cached
                return
            max_budget = user_row["max_budget"]
            curr_spend = user_row["spend"]

            if max_budget is None:
                return

            if curr_spend is None:
                return

            # CHECK IF REQUEST ALLOWED
            if curr_spend >= max_budget:
                raise HTTPException(status_code=429, detail="Max budget limit reached.")
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
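
The gate above is a straightforward spend-versus-budget comparison on the cached user row. A tiny sketch with a hypothetical record, for reference:

# Sketch of the budget gate (illustrative; user_row normally comes from the DualCache).
user_row = {"max_budget": 100.0, "spend": 120.5}  # hypothetical cached user record

max_budget = user_row["max_budget"]
curr_spend = user_row["spend"]
if max_budget is not None and curr_spend is not None and curr_spend >= max_budget:
    print("HTTP 429: Max budget limit reached.")  # mirrors the HTTPException raised above
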
@@ -0,0 +1,192 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import Span
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
BudgetConfig,
|
||||
GenericBudgetConfigType,
|
||||
StandardLoggingPayload,
|
||||
)
|
||||
|
||||
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
|
||||
|
||||
|
||||
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
|
||||
"""
|
||||
Handles budgets for model + virtual key
|
||||
|
||||
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
|
||||
"""
|
||||
|
||||
def __init__(self, dual_cache: DualCache):
|
||||
self.dual_cache = dual_cache
|
||||
self.redis_increment_operation_queue = []
|
||||
|
||||
async def is_key_within_model_budget(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
model: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the user_api_key_dict is within the model budget
|
||||
|
||||
Raises:
|
||||
BudgetExceededError: If the user_api_key_dict has exceeded the model budget
|
||||
"""
|
||||
_model_max_budget = user_api_key_dict.model_max_budget
|
||||
internal_model_max_budget: GenericBudgetConfigType = {}
|
||||
|
||||
for _model, _budget_info in _model_max_budget.items():
|
||||
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"internal_model_max_budget %s",
|
||||
json.dumps(internal_model_max_budget, indent=4, default=str),
|
||||
)
|
||||
|
||||
# check if current model is in internal_model_max_budget
|
||||
_current_model_budget_info = self._get_request_model_budget_config(
|
||||
model=model, internal_model_max_budget=internal_model_max_budget
|
||||
)
|
||||
if _current_model_budget_info is None:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Model {model} not found in internal_model_max_budget"
|
||||
)
|
||||
return True
|
||||
|
||||
# check if current model is within budget
|
||||
if (
|
||||
_current_model_budget_info.max_budget
|
||||
and _current_model_budget_info.max_budget > 0
|
||||
):
|
||||
_current_spend = await self._get_virtual_key_spend_for_model(
|
||||
user_api_key_hash=user_api_key_dict.token,
|
||||
model=model,
|
||||
key_budget_config=_current_model_budget_info,
|
||||
)
|
||||
if (
|
||||
_current_spend is not None
|
||||
and _current_model_budget_info.max_budget is not None
|
||||
and _current_spend > _current_model_budget_info.max_budget
|
||||
):
|
||||
raise litellm.BudgetExceededError(
|
||||
message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
|
||||
current_cost=_current_spend,
|
||||
max_budget=_current_model_budget_info.max_budget,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _get_virtual_key_spend_for_model(
|
||||
self,
|
||||
user_api_key_hash: Optional[str],
|
||||
model: str,
|
||||
key_budget_config: BudgetConfig,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get the current spend for a virtual key for a model
|
||||
|
||||
Lookup model in this order:
|
||||
1. model: directly look up `model`
|
||||
2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
|
||||
"""
|
||||
|
||||
# 1. model: directly look up `model`
|
||||
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}"
|
||||
_current_spend = await self.dual_cache.async_get_cache(
|
||||
key=virtual_key_model_spend_cache_key,
|
||||
)
|
||||
|
||||
if _current_spend is None:
|
||||
# 2. If 1 does not exist, check if the model was passed as {custom_llm_provider}/model
|
||||
# if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
|
||||
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
|
||||
_current_spend = await self.dual_cache.async_get_cache(
|
||||
key=virtual_key_model_spend_cache_key,
|
||||
)
|
||||
return _current_spend
|
||||
|
||||
def _get_request_model_budget_config(
|
||||
self, model: str, internal_model_max_budget: GenericBudgetConfigType
|
||||
) -> Optional[BudgetConfig]:
|
||||
"""
|
||||
Get the budget config for the request model
|
||||
|
||||
1. Check if `model` is in `internal_model_max_budget`
|
||||
2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
|
||||
"""
|
||||
return internal_model_max_budget.get(
|
||||
model, None
|
||||
) or internal_model_max_budget.get(
|
||||
self._get_model_without_custom_llm_provider(model), None
|
||||
)
|
||||
|
||||
def _get_model_without_custom_llm_provider(self, model: str) -> str:
|
||||
if "/" in model:
|
||||
return model.split("/")[-1]
|
||||
return model
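# --- Illustrative sketch (not part of this diff): the provider prefix is dropped
# by splitting on "/" and keeping the last segment. DualCache() with no arguments
# is assumed to build an in-memory-only cache.
limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(dual_cache=DualCache())
assert limiter._get_model_without_custom_llm_provider("openai/o1-preview") == "o1-preview"
assert limiter._get_model_without_custom_llm_provider("gpt-4o") == "gpt-4o"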
|
||||
|
||||
async def async_filter_deployments(
|
||||
self,
|
||||
model: str,
|
||||
healthy_deployments: List,
|
||||
messages: Optional[List[AllMessageValues]],
|
||||
request_kwargs: Optional[dict] = None,
|
||||
parent_otel_span: Optional[Span] = None, # type: ignore
|
||||
) -> List[dict]:
|
||||
return healthy_deployments
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Track spend for virtual key + model in DualCache
|
||||
|
||||
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
|
||||
"""
|
||||
verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
if standard_logging_payload is None:
|
||||
raise ValueError("standard_logging_payload is required")
|
||||
|
||||
_litellm_params: dict = kwargs.get("litellm_params", {}) or {}
|
||||
_metadata: dict = _litellm_params.get("metadata", {}) or {}
|
||||
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
|
||||
"user_api_key_model_max_budget", None
|
||||
)
|
||||
if (
|
||||
user_api_key_model_max_budget is None
|
||||
or len(user_api_key_model_max_budget) == 0
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
|
||||
user_api_key_model_max_budget,
|
||||
)
|
||||
return
|
||||
response_cost: float = standard_logging_payload.get("response_cost", 0)
|
||||
model = standard_logging_payload.get("model")
|
||||
|
||||
virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
|
||||
model = standard_logging_payload.get("model")
|
||||
if virtual_key is not None:
|
||||
budget_config = BudgetConfig(time_period="1d", budget_limit=0.1)
|
||||
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.budget_duration}"
|
||||
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
|
||||
await self._increment_spend_for_key(
|
||||
budget_config=budget_config,
|
||||
spend_key=virtual_spend_key,
|
||||
start_time_key=virtual_start_time_key,
|
||||
response_cost=response_cost,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"current state of in memory cache %s",
|
||||
json.dumps(
|
||||
self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
|
||||
),
|
||||
)
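# --- Illustrative sketch (not part of this diff): the shape of the spend-counter
# cache key written and read above. The hash, model and duration are hypothetical.
user_api_key_hash = "88a2..."  # hashed virtual key
model = "gpt-4o"
budget_duration = "1d"
spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{budget_duration}"
# -> "virtual_key_spend:88a2...:gpt-4o:1d"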
|
||||
@@ -0,0 +1,868 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm import DualCache, ModelResponse
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
||||
from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_utils import (
|
||||
get_key_model_rpm_limit,
|
||||
get_key_model_tpm_limit,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
InternalUsageCache = _InternalUsageCache
|
||||
else:
|
||||
Span = Any
|
||||
InternalUsageCache = Any
|
||||
|
||||
|
||||
class CacheObject(TypedDict):
|
||||
current_global_requests: Optional[dict]
|
||||
request_count_api_key: Optional[dict]
|
||||
request_count_api_key_model: Optional[dict]
|
||||
request_count_user_id: Optional[dict]
|
||||
request_count_team_id: Optional[dict]
|
||||
request_count_end_user_id: Optional[dict]
|
||||
|
||||
|
||||
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(self, internal_usage_cache: InternalUsageCache):
|
||||
self.internal_usage_cache = internal_usage_cache
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
try:
|
||||
verbose_proxy_logger.debug(print_statement)
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def check_key_in_limits(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: str,
|
||||
max_parallel_requests: int,
|
||||
tpm_limit: int,
|
||||
rpm_limit: int,
|
||||
current: Optional[dict],
|
||||
request_count_api_key: str,
|
||||
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
|
||||
values_to_update_in_cache: List[Tuple[Any, Any]],
|
||||
) -> dict:
|
||||
verbose_proxy_logger.info(
|
||||
f"Current Usage of {rate_limit_type} in this minute: {current}"
|
||||
)
|
||||
if current is None:
|
||||
if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
|
||||
# base case
|
||||
raise self.raise_rate_limit_error(
|
||||
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
|
||||
)
|
||||
new_val = {
|
||||
"current_requests": 1,
|
||||
"current_tpm": 0,
|
||||
"current_rpm": 1,
|
||||
}
|
||||
values_to_update_in_cache.append((request_count_api_key, new_val))
|
||||
elif (
|
||||
int(current["current_requests"]) < max_parallel_requests
|
||||
and current["current_tpm"] < tpm_limit
|
||||
and current["current_rpm"] < rpm_limit
|
||||
):
|
||||
# Increase count for this token
|
||||
new_val = {
|
||||
"current_requests": current["current_requests"] + 1,
|
||||
"current_tpm": current["current_tpm"],
|
||||
"current_rpm": current["current_rpm"] + 1,
|
||||
}
|
||||
values_to_update_in_cache.append((request_count_api_key, new_val))
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
|
||||
headers={"retry-after": str(self.time_to_next_minute())},
|
||||
)
|
||||
|
||||
await self.internal_usage_cache.async_batch_set_cache(
|
||||
cache_list=values_to_update_in_cache,
|
||||
ttl=60,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
local_only=True,
|
||||
)
|
||||
return new_val
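# --- Illustrative sketch (not part of this diff): the per-minute usage object
# that check_key_in_limits() reads and writes. The numbers are hypothetical.
current = {"current_requests": 3, "current_tpm": 1200, "current_rpm": 3}
max_parallel_requests, tpm_limit, rpm_limit = 10, 100_000, 60

allowed = (
    current["current_requests"] < max_parallel_requests
    and current["current_tpm"] < tpm_limit
    and current["current_rpm"] < rpm_limit
)
if allowed:
    new_val = {
        "current_requests": current["current_requests"] + 1,
        "current_tpm": current["current_tpm"],
        "current_rpm": current["current_rpm"] + 1,
    }  # one more in-flight request and one more request counted this minute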
|
||||
|
||||
def time_to_next_minute(self) -> float:
|
||||
# Get the current time
|
||||
now = datetime.now()
|
||||
|
||||
# Calculate the next minute
|
||||
next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
|
||||
|
||||
# Calculate the difference in seconds
|
||||
seconds_to_next_minute = (next_minute - now).total_seconds()
|
||||
|
||||
return seconds_to_next_minute
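# --- Illustrative sketch (not part of this diff): the retry-after value is simply
# the number of seconds left in the current clock minute.
from datetime import datetime, timedelta

now = datetime(2024, 1, 1, 12, 30, 45, 500000)
next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
assert (next_minute - now).total_seconds() == 14.5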
|
||||
|
||||
def raise_rate_limit_error(
|
||||
self, additional_details: Optional[str] = None
|
||||
) -> HTTPException:
|
||||
"""
|
||||
Raise an HTTPException with a 429 status code and a retry-after header
|
||||
"""
|
||||
error_message = "Max parallel request limit reached"
|
||||
if additional_details is not None:
|
||||
error_message = error_message + " " + additional_details
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Max parallel request limit reached {additional_details}",
|
||||
headers={"retry-after": str(self.time_to_next_minute())},
|
||||
)
|
||||
|
||||
async def get_all_cache_objects(
|
||||
self,
|
||||
current_global_requests: Optional[str],
|
||||
request_count_api_key: Optional[str],
|
||||
request_count_api_key_model: Optional[str],
|
||||
request_count_user_id: Optional[str],
|
||||
request_count_team_id: Optional[str],
|
||||
request_count_end_user_id: Optional[str],
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
) -> CacheObject:
|
||||
keys = [
|
||||
current_global_requests,
|
||||
request_count_api_key,
|
||||
request_count_api_key_model,
|
||||
request_count_user_id,
|
||||
request_count_team_id,
|
||||
request_count_end_user_id,
|
||||
]
|
||||
results = await self.internal_usage_cache.async_batch_get_cache(
|
||||
keys=keys,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
if results is None:
|
||||
return CacheObject(
|
||||
current_global_requests=None,
|
||||
request_count_api_key=None,
|
||||
request_count_api_key_model=None,
|
||||
request_count_user_id=None,
|
||||
request_count_team_id=None,
|
||||
request_count_end_user_id=None,
|
||||
)
|
||||
|
||||
return CacheObject(
|
||||
current_global_requests=results[0],
|
||||
request_count_api_key=results[1],
|
||||
request_count_api_key_model=results[2],
|
||||
request_count_user_id=results[3],
|
||||
request_count_team_id=results[4],
|
||||
request_count_end_user_id=results[5],
|
||||
)
|
||||
|
||||
async def async_pre_call_hook( # noqa: PLR0915
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: str,
|
||||
):
|
||||
self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
|
||||
api_key = user_api_key_dict.api_key
|
||||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
||||
if max_parallel_requests is None:
|
||||
max_parallel_requests = sys.maxsize
|
||||
if data is None:
|
||||
data = {}
|
||||
global_max_parallel_requests = data.get("metadata", {}).get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||
if tpm_limit is None:
|
||||
tpm_limit = sys.maxsize
|
||||
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
||||
if rpm_limit is None:
|
||||
rpm_limit = sys.maxsize
|
||||
|
||||
values_to_update_in_cache: List[
|
||||
Tuple[Any, Any]
|
||||
] = (
|
||||
[]
|
||||
) # values that need to get updated in cache, will run a batch_set_cache after this function
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
new_val: Optional[dict] = None
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
current_global_requests = await self.internal_usage_cache.async_get_cache(
|
||||
key=_key,
|
||||
local_only=True,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
)
|
||||
# check if below limit
|
||||
if current_global_requests is None:
|
||||
current_global_requests = 1
|
||||
# if above -> raise error
|
||||
if current_global_requests >= global_max_parallel_requests:
|
||||
return self.raise_rate_limit_error(
|
||||
additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}"
|
||||
)
|
||||
# if below -> increment
|
||||
else:
|
||||
await self.internal_usage_cache.async_increment_cache(
|
||||
key=_key,
|
||||
value=1,
|
||||
local_only=True,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
)
|
||||
_model = data.get("model", None)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
|
||||
cache_objects: CacheObject = await self.get_all_cache_objects(
|
||||
current_global_requests=(
|
||||
"global_max_parallel_requests"
|
||||
if global_max_parallel_requests is not None
|
||||
else None
|
||||
),
|
||||
request_count_api_key=(
|
||||
f"{api_key}::{precise_minute}::request_count"
|
||||
if api_key is not None
|
||||
else None
|
||||
),
|
||||
request_count_api_key_model=(
|
||||
f"{api_key}::{_model}::{precise_minute}::request_count"
|
||||
if api_key is not None and _model is not None
|
||||
else None
|
||||
),
|
||||
request_count_user_id=(
|
||||
f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
|
||||
if user_api_key_dict.user_id is not None
|
||||
else None
|
||||
),
|
||||
request_count_team_id=(
|
||||
f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
|
||||
if user_api_key_dict.team_id is not None
|
||||
else None
|
||||
),
|
||||
request_count_end_user_id=(
|
||||
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
|
||||
if user_api_key_dict.end_user_id is not None
|
||||
else None
|
||||
),
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
)
|
||||
if api_key is not None:
|
||||
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
|
||||
# CHECK IF REQUEST ALLOWED for key
|
||||
await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=max_parallel_requests,
|
||||
current=cache_objects["request_count_api_key"],
|
||||
request_count_api_key=request_count_api_key,
|
||||
tpm_limit=tpm_limit,
|
||||
rpm_limit=rpm_limit,
|
||||
rate_limit_type="key",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
|
||||
# Check if request under RPM/TPM per model for a given API Key
|
||||
if (
|
||||
get_key_model_tpm_limit(user_api_key_dict) is not None
|
||||
or get_key_model_rpm_limit(user_api_key_dict) is not None
|
||||
):
|
||||
_model = data.get("model", None)
|
||||
request_count_api_key = (
|
||||
f"{api_key}::{_model}::{precise_minute}::request_count"
|
||||
)
|
||||
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
|
||||
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
|
||||
tpm_limit_for_model = None
|
||||
rpm_limit_for_model = None
|
||||
|
||||
if _model is not None:
|
||||
if _tpm_limit_for_key_model:
|
||||
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
|
||||
|
||||
if _rpm_limit_for_key_model:
|
||||
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
|
||||
|
||||
new_val = await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a model
|
||||
current=cache_objects["request_count_api_key_model"],
|
||||
request_count_api_key=request_count_api_key,
|
||||
tpm_limit=tpm_limit_for_model or sys.maxsize,
|
||||
rpm_limit=rpm_limit_for_model or sys.maxsize,
|
||||
rate_limit_type="model_per_key",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
_remaining_tokens = None
|
||||
_remaining_requests = None
|
||||
# Add remaining tokens, requests to metadata
|
||||
if new_val:
|
||||
if tpm_limit_for_model is not None:
|
||||
_remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
|
||||
if rpm_limit_for_model is not None:
|
||||
_remaining_requests = rpm_limit_for_model - new_val["current_rpm"]
|
||||
|
||||
_remaining_limits_data = {
|
||||
f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
|
||||
f"litellm-key-remaining-requests-{_model}": _remaining_requests,
|
||||
}
|
||||
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
data["metadata"].update(_remaining_limits_data)
|
||||
|
||||
# check if REQUEST ALLOWED for user_id
|
||||
user_id = user_api_key_dict.user_id
|
||||
if user_id is not None:
|
||||
user_tpm_limit = user_api_key_dict.user_tpm_limit
|
||||
user_rpm_limit = user_api_key_dict.user_rpm_limit
|
||||
if user_tpm_limit is None:
|
||||
user_tpm_limit = sys.maxsize
|
||||
if user_rpm_limit is None:
|
||||
user_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||
await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
||||
current=cache_objects["request_count_user_id"],
|
||||
request_count_api_key=request_count_api_key,
|
||||
tpm_limit=user_tpm_limit,
|
||||
rpm_limit=user_rpm_limit,
|
||||
rate_limit_type="user",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
|
||||
# TEAM RATE LIMITS
|
||||
## get team tpm/rpm limits
|
||||
team_id = user_api_key_dict.team_id
|
||||
if team_id is not None:
|
||||
team_tpm_limit = user_api_key_dict.team_tpm_limit
|
||||
team_rpm_limit = user_api_key_dict.team_rpm_limit
|
||||
|
||||
if team_tpm_limit is None:
|
||||
team_tpm_limit = sys.maxsize
|
||||
if team_rpm_limit is None:
|
||||
team_rpm_limit = sys.maxsize
|
||||
|
||||
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
|
||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||
await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
|
||||
current=cache_objects["request_count_team_id"],
|
||||
request_count_api_key=request_count_api_key,
|
||||
tpm_limit=team_tpm_limit,
|
||||
rpm_limit=team_rpm_limit,
|
||||
rate_limit_type="team",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
|
||||
# End-User Rate Limits
|
||||
# Only enforce if user passed `user` to /chat, /completions, /embeddings
|
||||
if user_api_key_dict.end_user_id:
|
||||
end_user_tpm_limit = getattr(
|
||||
user_api_key_dict, "end_user_tpm_limit", sys.maxsize
|
||||
)
|
||||
end_user_rpm_limit = getattr(
|
||||
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
|
||||
)
|
||||
|
||||
if end_user_tpm_limit is None:
|
||||
end_user_tpm_limit = sys.maxsize
|
||||
if end_user_rpm_limit is None:
|
||||
end_user_rpm_limit = sys.maxsize
|
||||
|
||||
# now do the same tpm/rpm checks
|
||||
request_count_api_key = (
|
||||
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||
await self.check_key_in_limits(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=cache,
|
||||
data=data,
|
||||
call_type=call_type,
|
||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
|
||||
request_count_api_key=request_count_api_key,
|
||||
current=cache_objects["request_count_end_user_id"],
|
||||
tpm_limit=end_user_tpm_limit,
|
||||
rpm_limit=end_user_rpm_limit,
|
||||
rate_limit_type="customer",
|
||||
values_to_update_in_cache=values_to_update_in_cache,
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.internal_usage_cache.async_batch_set_cache(
|
||||
cache_list=values_to_update_in_cache,
|
||||
ttl=60,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
) # don't block execution for cache updates
|
||||
)
|
||||
|
||||
return
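# --- Illustrative sketch (not part of this diff): the per-minute cache keys the
# pre-call hook builds for each rate-limit scope. All values are hypothetical.
precise_minute = "2024-01-01-12-30"  # f"{date}-{hour}-{minute}"
api_key, model, user_id, team_id, end_user_id = "hashed-key", "gpt-4o", "u1", "t1", "customer-1"

scope_keys = {
    "key": f"{api_key}::{precise_minute}::request_count",
    "model_per_key": f"{api_key}::{model}::{precise_minute}::request_count",
    "user": f"{user_id}::{precise_minute}::request_count",
    "team": f"{team_id}::{precise_minute}::request_count",
    "customer": f"{end_user_id}::{precise_minute}::request_count",
}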
|
||||
|
||||
async def async_log_success_event( # noqa: PLR0915
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
):
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_model_group_from_litellm_kwargs,
|
||||
)
|
||||
|
||||
litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
|
||||
kwargs=kwargs
|
||||
)
|
||||
try:
|
||||
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
|
||||
|
||||
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
|
||||
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_user_id", None
|
||||
)
|
||||
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_team_id", None
|
||||
)
|
||||
user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
|
||||
"user_api_key_model_max_budget", None
|
||||
)
|
||||
user_api_key_end_user_id = kwargs.get("user")
|
||||
|
||||
user_api_key_metadata = (
|
||||
kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {})
|
||||
or {}
|
||||
)
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
# decrement
|
||||
await self.internal_usage_cache.async_increment_cache(
|
||||
key=_key,
|
||||
value=-1,
|
||||
local_only=True,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
total_tokens = response_obj.usage.total_tokens # type: ignore
|
||||
|
||||
# ------------
|
||||
# Update usage - API Key
|
||||
# ------------
|
||||
|
||||
values_to_update_in_cache = []
|
||||
|
||||
if user_api_key is not None:
|
||||
request_count_api_key = (
|
||||
f"{user_api_key}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
current = await self.internal_usage_cache.async_get_cache(
|
||||
key=request_count_api_key,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
) or {
|
||||
"current_requests": 1,
|
||||
"current_tpm": 0,
|
||||
"current_rpm": 0,
|
||||
}
|
||||
|
||||
new_val = {
|
||||
"current_requests": max(current["current_requests"] - 1, 0),
|
||||
"current_tpm": current["current_tpm"] + total_tokens,
|
||||
"current_rpm": current["current_rpm"],
|
||||
}
|
||||
|
||||
self.print_verbose(
|
||||
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||
)
|
||||
values_to_update_in_cache.append((request_count_api_key, new_val))
|
||||
|
||||
# ------------
|
||||
# Update usage - model group + API Key
|
||||
# ------------
|
||||
model_group = get_model_group_from_litellm_kwargs(kwargs)
|
||||
if (
|
||||
user_api_key is not None
|
||||
and model_group is not None
|
||||
and (
|
||||
"model_rpm_limit" in user_api_key_metadata
|
||||
or "model_tpm_limit" in user_api_key_metadata
|
||||
or user_api_key_model_max_budget is not None
|
||||
)
|
||||
):
|
||||
request_count_api_key = (
|
||||
f"{user_api_key}::{model_group}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
current = await self.internal_usage_cache.async_get_cache(
|
||||
key=request_count_api_key,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
) or {
|
||||
"current_requests": 1,
|
||||
"current_tpm": 0,
|
||||
"current_rpm": 0,
|
||||
}
|
||||
|
||||
new_val = {
|
||||
"current_requests": max(current["current_requests"] - 1, 0),
|
||||
"current_tpm": current["current_tpm"] + total_tokens,
|
||||
"current_rpm": current["current_rpm"],
|
||||
}
|
||||
|
||||
self.print_verbose(
|
||||
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||
)
|
||||
values_to_update_in_cache.append((request_count_api_key, new_val))
|
||||
|
||||
# ------------
|
||||
# Update usage - User
|
||||
# ------------
|
||||
if user_api_key_user_id is not None:
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
total_tokens = response_obj.usage.total_tokens # type: ignore
|
||||
|
||||
request_count_api_key = (
|
||||
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
current = await self.internal_usage_cache.async_get_cache(
|
||||
key=request_count_api_key,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
) or {
|
||||
"current_requests": 1,
|
||||
"current_tpm": total_tokens,
|
||||
"current_rpm": 1,
|
||||
}
|
||||
|
||||
new_val = {
|
||||
"current_requests": max(current["current_requests"] - 1, 0),
|
||||
"current_tpm": current["current_tpm"] + total_tokens,
|
||||
"current_rpm": current["current_rpm"],
|
||||
}
|
||||
|
||||
self.print_verbose(
|
||||
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||
)
|
||||
values_to_update_in_cache.append((request_count_api_key, new_val))
|
||||
|
||||
# ------------
|
||||
# Update usage - Team
|
||||
# ------------
|
||||
if user_api_key_team_id is not None:
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
total_tokens = response_obj.usage.total_tokens # type: ignore
|
||||
|
||||
request_count_api_key = (
|
||||
f"{user_api_key_team_id}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
current = await self.internal_usage_cache.async_get_cache(
|
||||
key=request_count_api_key,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
) or {
|
||||
"current_requests": 1,
|
||||
"current_tpm": total_tokens,
|
||||
"current_rpm": 1,
|
||||
}
|
||||
|
||||
new_val = {
|
||||
"current_requests": max(current["current_requests"] - 1, 0),
|
||||
"current_tpm": current["current_tpm"] + total_tokens,
|
||||
"current_rpm": current["current_rpm"],
|
||||
}
|
||||
|
||||
self.print_verbose(
|
||||
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||
)
|
||||
values_to_update_in_cache.append((request_count_api_key, new_val))
|
||||
|
||||
# ------------
|
||||
# Update usage - End User
|
||||
# ------------
|
||||
if user_api_key_end_user_id is not None:
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
total_tokens = response_obj.usage.total_tokens # type: ignore
|
||||
|
||||
request_count_api_key = (
|
||||
f"{user_api_key_end_user_id}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
current = await self.internal_usage_cache.async_get_cache(
|
||||
key=request_count_api_key,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
) or {
|
||||
"current_requests": 1,
|
||||
"current_tpm": total_tokens,
|
||||
"current_rpm": 1,
|
||||
}
|
||||
|
||||
new_val = {
|
||||
"current_requests": max(current["current_requests"] - 1, 0),
|
||||
"current_tpm": current["current_tpm"] + total_tokens,
|
||||
"current_rpm": current["current_rpm"],
|
||||
}
|
||||
|
||||
self.print_verbose(
|
||||
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||
)
|
||||
values_to_update_in_cache.append((request_count_api_key, new_val))
|
||||
|
||||
await self.internal_usage_cache.async_batch_set_cache(
|
||||
cache_list=values_to_update_in_cache,
|
||||
ttl=60,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
except Exception as e:
|
||||
self.print_verbose(e) # noqa
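# --- Illustrative sketch (not part of this diff): on success the in-flight count
# is decremented and the response's tokens are added to the minute's TPM. The
# numbers are hypothetical.
current = {"current_requests": 3, "current_tpm": 1200, "current_rpm": 3}
total_tokens = 450

new_val = {
    "current_requests": max(current["current_requests"] - 1, 0),
    "current_tpm": current["current_tpm"] + total_tokens,
    "current_rpm": current["current_rpm"],
}
assert new_val == {"current_requests": 2, "current_tpm": 1650, "current_rpm": 3}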
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.print_verbose("Inside Max Parallel Request Failure Hook")
|
||||
litellm_parent_otel_span: Union[
|
||||
Span, None
|
||||
] = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||
_metadata = kwargs["litellm_params"].get("metadata", {}) or {}
|
||||
global_max_parallel_requests = _metadata.get(
|
||||
"global_max_parallel_requests", None
|
||||
)
|
||||
user_api_key = _metadata.get("user_api_key", None)
|
||||
self.print_verbose(f"user_api_key: {user_api_key}")
|
||||
if user_api_key is None:
|
||||
return
|
||||
|
||||
## decrement call count if call failed
|
||||
if CommonProxyErrors.max_parallel_request_limit_reached.value in str(
|
||||
kwargs["exception"]
|
||||
):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
|
||||
if global_max_parallel_requests is not None:
|
||||
# get value from cache
|
||||
_key = "global_max_parallel_requests"
|
||||
(
|
||||
await self.internal_usage_cache.async_get_cache(
|
||||
key=_key,
|
||||
local_only=True,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
)
|
||||
# decrement
|
||||
await self.internal_usage_cache.async_increment_cache(
|
||||
key=_key,
|
||||
value=-1,
|
||||
local_only=True,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
|
||||
request_count_api_key = (
|
||||
f"{user_api_key}::{precise_minute}::request_count"
|
||||
)
|
||||
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
current = await self.internal_usage_cache.async_get_cache(
|
||||
key=request_count_api_key,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
) or {
|
||||
"current_requests": 1,
|
||||
"current_tpm": 0,
|
||||
"current_rpm": 0,
|
||||
}
|
||||
|
||||
new_val = {
|
||||
"current_requests": max(current["current_requests"] - 1, 0),
|
||||
"current_tpm": current["current_tpm"],
|
||||
"current_rpm": current["current_rpm"],
|
||||
}
|
||||
|
||||
self.print_verbose(f"updated_value in failure call: {new_val}")
|
||||
await self.internal_usage_cache.async_set_cache(
|
||||
request_count_api_key,
|
||||
new_val,
|
||||
ttl=60,
|
||||
litellm_parent_otel_span=litellm_parent_otel_span,
|
||||
) # save in cache for up to 1 min.
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"Inside Parallel Request Limiter: An exception occurred - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
async def get_internal_user_object(
|
||||
self,
|
||||
user_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Helper to get the 'Internal User Object'
|
||||
|
||||
It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`
|
||||
|
||||
We need this because the UserApiKeyAuth object does not contain the rpm/tpm limits for a User AND there could be a perf impact by additionally reading the UserTable.
|
||||
"""
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.auth_checks import get_user_object
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
try:
|
||||
_user_id_rate_limits = await get_user_object(
|
||||
user_id=user_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=self.internal_usage_cache.dual_cache,
|
||||
user_id_upsert=False,
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
if _user_id_rate_limits is None:
|
||||
return None
|
||||
|
||||
return _user_id_rate_limits.model_dump()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
"Parallel Request Limiter: Error getting user object", str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
|
||||
):
|
||||
"""
|
||||
Retrieve the key's remaining rate limits.
|
||||
"""
|
||||
api_key = user_api_key_dict.api_key
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
|
||||
current: Optional[
|
||||
CurrentItemRateLimit
|
||||
] = await self.internal_usage_cache.async_get_cache(
|
||||
key=request_count_api_key,
|
||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
)
|
||||
|
||||
key_remaining_rpm_limit: Optional[int] = None
|
||||
key_rpm_limit: Optional[int] = None
|
||||
key_remaining_tpm_limit: Optional[int] = None
|
||||
key_tpm_limit: Optional[int] = None
|
||||
if current is not None:
|
||||
if user_api_key_dict.rpm_limit is not None:
|
||||
key_remaining_rpm_limit = (
|
||||
user_api_key_dict.rpm_limit - current["current_rpm"]
|
||||
)
|
||||
key_rpm_limit = user_api_key_dict.rpm_limit
|
||||
if user_api_key_dict.tpm_limit is not None:
|
||||
key_remaining_tpm_limit = (
|
||||
user_api_key_dict.tpm_limit - current["current_tpm"]
|
||||
)
|
||||
key_tpm_limit = user_api_key_dict.tpm_limit
|
||||
|
||||
if hasattr(response, "_hidden_params"):
|
||||
_hidden_params = getattr(response, "_hidden_params")
|
||||
else:
|
||||
_hidden_params = None
|
||||
if _hidden_params is not None and (
|
||||
isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict)
|
||||
):
|
||||
if isinstance(_hidden_params, BaseModel):
|
||||
_hidden_params = _hidden_params.model_dump()
|
||||
|
||||
_additional_headers = _hidden_params.get("additional_headers", {}) or {}
|
||||
|
||||
if key_remaining_rpm_limit is not None:
|
||||
_additional_headers[
|
||||
"x-ratelimit-remaining-requests"
|
||||
] = key_remaining_rpm_limit
|
||||
if key_rpm_limit is not None:
|
||||
_additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
|
||||
if key_remaining_tpm_limit is not None:
|
||||
_additional_headers[
|
||||
"x-ratelimit-remaining-tokens"
|
||||
] = key_remaining_tpm_limit
|
||||
if key_tpm_limit is not None:
|
||||
_additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit
|
||||
|
||||
setattr(
|
||||
response,
|
||||
"_hidden_params",
|
||||
{**_hidden_params, "additional_headers": _additional_headers},
|
||||
)
|
||||
|
||||
return await super().async_post_call_success_hook(
|
||||
data, user_api_key_dict, response
|
||||
)
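# --- Illustrative sketch (not part of this diff): how the x-ratelimit-* response
# headers are derived from the key's limits and the per-minute counters. The
# limit and counter values are hypothetical.
rpm_limit, tpm_limit = 60, 100_000
current = {"current_requests": 2, "current_rpm": 12, "current_tpm": 3400}

additional_headers = {
    "x-ratelimit-limit-requests": rpm_limit,
    "x-ratelimit-remaining-requests": rpm_limit - current["current_rpm"],  # 48
    "x-ratelimit-limit-tokens": tpm_limit,
    "x-ratelimit-remaining-tokens": tpm_limit - current["current_tpm"],  # 96600
}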
|
||||
@@ -0,0 +1,282 @@
|
||||
# +------------------------------------+
|
||||
#
|
||||
# Prompt Injection Detection
|
||||
#
|
||||
# +------------------------------------+
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
## Reject a call if it contains a prompt injection attack.
|
||||
|
||||
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.constants import DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
prompt_injection_detection_default_pt,
|
||||
)
|
||||
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
|
||||
from litellm.router import Router
|
||||
from litellm.utils import get_formatted_prompt
|
||||
|
||||
|
||||
class _OPTIONAL_PromptInjectionDetection(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
|
||||
):
|
||||
self.prompt_injection_params = prompt_injection_params
|
||||
self.llm_router: Optional[Router] = None
|
||||
|
||||
self.verbs = [
|
||||
"Ignore",
|
||||
"Disregard",
|
||||
"Skip",
|
||||
"Forget",
|
||||
"Neglect",
|
||||
"Overlook",
|
||||
"Omit",
|
||||
"Bypass",
|
||||
"Pay no attention to",
|
||||
"Do not follow",
|
||||
"Do not obey",
|
||||
]
|
||||
self.adjectives = [
|
||||
"",
|
||||
"prior",
|
||||
"previous",
|
||||
"preceding",
|
||||
"above",
|
||||
"foregoing",
|
||||
"earlier",
|
||||
"initial",
|
||||
]
|
||||
self.prepositions = [
|
||||
"",
|
||||
"and start over",
|
||||
"and start anew",
|
||||
"and begin afresh",
|
||||
"and start from scratch",
|
||||
]
|
||||
|
||||
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
|
||||
if level == "INFO":
|
||||
verbose_proxy_logger.info(print_statement)
|
||||
elif level == "DEBUG":
|
||||
verbose_proxy_logger.debug(print_statement)
|
||||
|
||||
if litellm.set_verbose is True:
|
||||
print(print_statement) # noqa
|
||||
|
||||
def update_environment(self, router: Optional[Router] = None):
|
||||
self.llm_router = router
|
||||
|
||||
if (
|
||||
self.prompt_injection_params is not None
|
||||
and self.prompt_injection_params.llm_api_check is True
|
||||
):
|
||||
if self.llm_router is None:
|
||||
raise Exception(
|
||||
"PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
|
||||
)
|
||||
|
||||
self.print_verbose(
|
||||
f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
|
||||
)
|
||||
if (
|
||||
self.prompt_injection_params.llm_api_name is None
|
||||
or self.prompt_injection_params.llm_api_name
|
||||
not in self.llm_router.model_names
|
||||
):
|
||||
raise Exception(
|
||||
"PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
|
||||
)
|
||||
|
||||
def generate_injection_keywords(self) -> List[str]:
|
||||
combinations = []
|
||||
for verb in self.verbs:
|
||||
for adj in self.adjectives:
|
||||
for prep in self.prepositions:
|
||||
phrase = " ".join(filter(None, [verb, adj, prep])).strip()
|
||||
if (
|
||||
len(phrase.split()) > 2
|
||||
): # additional check to ensure more than 2 words
|
||||
combinations.append(phrase.lower())
|
||||
return combinations
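# --- Illustrative sketch (not part of this diff): a few phrases the verb /
# adjective / preposition lists above generate; two-word combinations are
# filtered out by the length check.
detector = _OPTIONAL_PromptInjectionDetection()
phrases = detector.generate_injection_keywords()
assert "ignore previous and start over" in phrases
assert "pay no attention to" in phrases
assert "ignore previous" not in phrases  # only two words, so it is skipped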
|
||||
|
||||
def check_user_input_similarity(
|
||||
self,
|
||||
user_input: str,
|
||||
similarity_threshold: float = DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD,
|
||||
) -> bool:
|
||||
user_input_lower = user_input.lower()
|
||||
keywords = self.generate_injection_keywords()
|
||||
|
||||
for keyword in keywords:
|
||||
# Calculate the length of the keyword to extract substrings of the same length from user input
|
||||
keyword_length = len(keyword)
|
||||
|
||||
for i in range(len(user_input_lower) - keyword_length + 1):
|
||||
# Extract a substring of the same length as the keyword
|
||||
substring = user_input_lower[i : i + keyword_length]
|
||||
|
||||
# Calculate similarity
|
||||
match_ratio = SequenceMatcher(None, substring, keyword).ratio()
|
||||
if match_ratio > similarity_threshold:
|
||||
self.print_verbose(
|
||||
print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
|
||||
level="INFO",
|
||||
)
|
||||
return True # Found a highly similar substring
|
||||
return False # No substring crossed the threshold
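# --- Illustrative sketch (not part of this diff): the sliding-window check
# compares every same-length substring of the user input against each keyword.
from difflib import SequenceMatcher

keyword = "ignore previous and start over"
user_input = "please ignore previous and start over now".lower()
window = len(keyword)
best = max(
    SequenceMatcher(None, user_input[i : i + window], keyword).ratio()
    for i in range(len(user_input) - window + 1)
)
assert best == 1.0  # an exact substring match crosses any sane threshold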
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
|
||||
):
|
||||
try:
|
||||
"""
|
||||
- normalize the request into a single prompt string
- run the configured heuristics check and reject the call if it looks like a prompt injection
|
||||
"""
|
||||
self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
|
||||
try:
|
||||
assert call_type in [
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
]
|
||||
except Exception:
|
||||
self.print_verbose(
|
||||
f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
|
||||
)
|
||||
return data
|
||||
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
|
||||
|
||||
is_prompt_attack = False
|
||||
|
||||
if self.prompt_injection_params is not None:
|
||||
# 1. check if heuristics check turned on
|
||||
if self.prompt_injection_params.heuristics_check is True:
|
||||
is_prompt_attack = self.check_user_input_similarity(
|
||||
user_input=formatted_prompt
|
||||
)
|
||||
if is_prompt_attack is True:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Rejected message. This is a prompt injection attack."
|
||||
},
|
||||
)
|
||||
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
|
||||
if self.prompt_injection_params.vector_db_check is True:
|
||||
pass
|
||||
else:
|
||||
is_prompt_attack = self.check_user_input_similarity(
|
||||
user_input=formatted_prompt
|
||||
)
|
||||
|
||||
if is_prompt_attack is True:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Rejected message. This is a prompt injection attack."
|
||||
},
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
except HTTPException as e:
|
||||
if (
|
||||
e.status_code == 400
|
||||
and isinstance(e.detail, dict)
|
||||
and "error" in e.detail # type: ignore
|
||||
and self.prompt_injection_params is not None
|
||||
and self.prompt_injection_params.reject_as_response
|
||||
):
|
||||
return e.detail.get("error")
|
||||
raise e
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
async def async_moderation_hook( # type: ignore
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
],
|
||||
) -> Optional[bool]:
|
||||
self.print_verbose(
|
||||
f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
|
||||
)
|
||||
|
||||
if self.prompt_injection_params is None:
|
||||
return None
|
||||
|
||||
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
|
||||
is_prompt_attack = False
|
||||
|
||||
prompt_injection_system_prompt = getattr(
|
||||
self.prompt_injection_params,
|
||||
"llm_api_system_prompt",
|
||||
prompt_injection_detection_default_pt(),
|
||||
)
|
||||
|
||||
# 3. check if llm api check turned on
|
||||
if (
|
||||
self.prompt_injection_params.llm_api_check is True
|
||||
and self.prompt_injection_params.llm_api_name is not None
|
||||
and self.llm_router is not None
|
||||
):
|
||||
# make a call to the llm api
|
||||
response = await self.llm_router.acompletion(
|
||||
model=self.prompt_injection_params.llm_api_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": prompt_injection_system_prompt,
|
||||
},
|
||||
{"role": "user", "content": formatted_prompt},
|
||||
],
|
||||
)
|
||||
|
||||
self.print_verbose(f"Received LLM Moderation response: {response}")
|
||||
self.print_verbose(
|
||||
f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
|
||||
)
|
||||
if isinstance(response, litellm.ModelResponse) and isinstance(
|
||||
response.choices[0], litellm.Choices
|
||||
):
|
||||
if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore
|
||||
is_prompt_attack = True
|
||||
|
||||
if is_prompt_attack is True:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Rejected message. This is a prompt injection attack."
|
||||
},
|
||||
)
|
||||
|
||||
return is_prompt_attack
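# --- Illustrative sketch (not part of this diff): wiring up the detector. The
# LiteLLMPromptInjectionParams field names are assumed from the attributes read
# above; "my-guard-model" and `llm_router` are hypothetical.
params = LiteLLMPromptInjectionParams(
    heuristics_check=True,
    vector_db_check=False,
    llm_api_check=True,
    llm_api_name="my-guard-model",  # must be a model_name in the router's model_list
    llm_api_fail_call_string="block",  # substring in the guard model's reply that marks an attack
    reject_as_response=False,
)
detector = _OPTIONAL_PromptInjectionDetection(prompt_injection_params=params)
detector.update_environment(router=llm_router)  # llm_router: a configured litellm Router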
|
||||
@@ -0,0 +1,257 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.utils import ProxyUpdateSpend
|
||||
from litellm.types.utils import (
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingUserAPIKeyMetadata,
|
||||
)
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
class _ProxyDBLogger(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
await self._PROXY_track_cost_callback(
|
||||
kwargs, response_obj, start_time, end_time
|
||||
)
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self,
|
||||
request_data: dict,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
):
|
||||
request_route = user_api_key_dict.request_route
|
||||
if _ProxyDBLogger._should_track_errors_in_db() is False:
|
||||
return
|
||||
elif request_route is not None and not RouteChecks.is_llm_api_route(
|
||||
route=request_route
|
||||
):
|
||||
return
|
||||
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
_metadata = dict(
|
||||
StandardLoggingUserAPIKeyMetadata(
|
||||
user_api_key_hash=user_api_key_dict.api_key,
|
||||
user_api_key_alias=user_api_key_dict.key_alias,
|
||||
user_api_key_user_email=user_api_key_dict.user_email,
|
||||
user_api_key_user_id=user_api_key_dict.user_id,
|
||||
user_api_key_team_id=user_api_key_dict.team_id,
|
||||
user_api_key_org_id=user_api_key_dict.org_id,
|
||||
user_api_key_team_alias=user_api_key_dict.team_alias,
|
||||
user_api_key_end_user_id=user_api_key_dict.end_user_id,
|
||||
)
|
||||
)
|
||||
_metadata["user_api_key"] = user_api_key_dict.api_key
|
||||
_metadata["status"] = "failure"
|
||||
_metadata[
|
||||
"error_information"
|
||||
] = StandardLoggingPayloadSetup.get_error_information(
|
||||
original_exception=original_exception,
|
||||
)
|
||||
|
||||
existing_metadata: dict = request_data.get("metadata", None) or {}
|
||||
existing_metadata.update(_metadata)
|
||||
|
||||
if "litellm_params" not in request_data:
|
||||
request_data["litellm_params"] = {}
|
||||
request_data["litellm_params"]["proxy_server_request"] = (
|
||||
request_data.get("proxy_server_request") or {}
|
||||
)
|
||||
request_data["litellm_params"]["metadata"] = existing_metadata
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key_dict.api_key,
|
||||
response_cost=0.0,
|
||||
user_id=user_api_key_dict.user_id,
|
||||
end_user_id=user_api_key_dict.end_user_id,
|
||||
team_id=user_api_key_dict.team_id,
|
||||
kwargs=request_data,
|
||||
completion_response=original_exception,
|
||||
start_time=datetime.now(),
|
||||
end_time=datetime.now(),
|
||||
org_id=user_api_key_dict.org_id,
|
||||
)
|
||||
|
||||
@log_db_metrics
|
||||
async def _PROXY_track_cost_callback(
|
||||
self,
|
||||
kwargs, # kwargs to completion
|
||||
completion_response: Optional[
|
||||
Union[litellm.ModelResponse, Any]
|
||||
], # response from completion
|
||||
start_time=None,
|
||||
end_time=None, # start/end time for completion
|
||||
):
|
||||
from litellm.proxy.proxy_server import (
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
update_cache,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
|
||||
try:
|
||||
verbose_proxy_logger.debug(
|
||||
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
|
||||
)
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
|
||||
team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
|
||||
org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
|
||||
key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
|
||||
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
|
||||
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
response_cost = (
|
||||
sl_object.get("response_cost", None)
|
||||
if sl_object is not None
|
||||
else kwargs.get("response_cost", None)
|
||||
)
|
||||
|
||||
if response_cost is not None:
|
||||
user_api_key = metadata.get("user_api_key", None)
|
||||
if kwargs.get("cache_hit", False) is True:
|
||||
response_cost = 0.0
|
||||
verbose_proxy_logger.info(
|
||||
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
|
||||
)
|
||||
if _should_track_cost_callback(
|
||||
user_api_key=user_api_key,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
end_user_id=end_user_id,
|
||||
):
|
||||
## UPDATE DATABASE
|
||||
await proxy_logging_obj.db_spend_update_writer.update_database(
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
team_id=team_id,
|
||||
kwargs=kwargs,
|
||||
completion_response=completion_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# update cache
|
||||
asyncio.create_task(
|
||||
update_cache(
|
||||
token=user_api_key,
|
||||
user_id=user_id,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
team_id=team_id,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
)
|
||||
|
||||
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
|
||||
token=user_api_key,
|
||||
key_alias=key_alias,
|
||||
end_user_id=end_user_id,
|
||||
response_cost=response_cost,
|
||||
max_budget=end_user_max_budget,
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"User API key and team id and user id missing from custom callback."
|
||||
)
|
||||
else:
|
||||
if kwargs["stream"] is not True or (
|
||||
kwargs["stream"] is True and "complete_streaming_response" in kwargs
|
||||
):
|
||||
if sl_object is not None:
|
||||
cost_tracking_failure_debug_info: Union[dict, str] = (
|
||||
sl_object["response_cost_failure_debug_info"] # type: ignore
|
||||
or "response_cost_failure_debug_info is None in standard_logging_object"
|
||||
)
|
||||
else:
|
||||
cost_tracking_failure_debug_info = (
|
||||
"standard_logging_object not found"
|
||||
)
|
||||
model = kwargs.get("model")
|
||||
raise Exception(
|
||||
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
|
||||
model = kwargs.get("model", "")
|
||||
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
|
||||
litellm_metadata = kwargs.get("litellm_params", {}).get(
|
||||
"litellm_metadata", {}
|
||||
)
|
||||
old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
|
||||
call_type = kwargs.get("call_type", "")
|
||||
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.failed_tracking_alert(
|
||||
error_message=error_msg,
|
||||
failing_model=model,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.exception(
|
||||
"Error in tracking cost callback - %s", str(e)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_track_errors_in_db():
|
||||
"""
|
||||
Returns True if errors should be tracked in the database
|
||||
|
||||
By default, errors are tracked in the database
|
||||
|
||||
If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
if general_settings.get("disable_error_logs") is True:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _should_track_cost_callback(
|
||||
user_api_key: Optional[str],
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
end_user_id: Optional[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if the cost callback should be tracked based on the kwargs
|
||||
"""
|
||||
|
||||
# don't run track cost callback if user opted into disabling spend
|
||||
if ProxyUpdateSpend.disable_spend_updates() is True:
|
||||
return False
|
||||
|
||||
if (
|
||||
user_api_key is not None
|
||||
or user_id is not None
|
||||
or team_id is not None
|
||||
or end_user_id is not None
|
||||
):
|
||||
return True
|
||||
return False
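# --- Illustrative sketch (not part of this diff): cost tracking runs when any
# identity is attached to the call, assuming spend updates are not globally
# disabled via ProxyUpdateSpend.disable_spend_updates().
assert _should_track_cost_callback("hashed-key", None, None, None) is True
assert _should_track_cost_callback(None, None, None, None) is False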