structure saas with tools

commit 1aef473937
Author: Davidson Gomes
Date: 2025-04-25 15:30:54 -03:00
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,36 @@
from typing import Literal, Union
from . import *
from .cache_control_check import _PROXY_CacheControlCheck
from .managed_files import _PROXY_LiteLLMManagedFiles
from .max_budget_limiter import _PROXY_MaxBudgetLimiter
from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
# List of all available hooks that can be enabled
PROXY_HOOKS = {
"max_budget_limiter": _PROXY_MaxBudgetLimiter,
"managed_files": _PROXY_LiteLLMManagedFiles,
"parallel_request_limiter": _PROXY_MaxParallelRequestsHandler,
"cache_control_check": _PROXY_CacheControlCheck,
}
def get_proxy_hook(
hook_name: Union[
Literal[
"max_budget_limiter",
"managed_files",
"parallel_request_limiter",
"cache_control_check",
],
str,
]
):
"""
Factory method to get a proxy hook class by name
"""
if hook_name not in PROXY_HOOKS:
raise ValueError(
f"Unknown hook: {hook_name}. Available hooks: {list(PROXY_HOOKS.keys())}"
)
return PROXY_HOOKS[hook_name]
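
get_proxy_hook returns the registered class rather than an instance, so callers construct the hook themselves (some hooks, such as the managed-files handler, also need constructor arguments). A minimal usage sketch, assuming this package is importable as litellm.proxy.hooks (file names are not shown in this diff view):

# Illustrative sketch only; not part of this commit. Import path is assumed.
from litellm.proxy.hooks import PROXY_HOOKS, get_proxy_hook

hook_cls = get_proxy_hook("cache_control_check")  # returns the class, not an instance
hook = hook_cls()                                 # construct it yourself

print(sorted(PROXY_HOOKS))
# ['cache_control_check', 'managed_files', 'max_budget_limiter', 'parallel_request_limiter']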

View File

@@ -0,0 +1,156 @@
import traceback
from typing import Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_AzureContentSafety(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
# Class variables or attributes
def __init__(self, endpoint, api_key, thresholds=None):
try:
from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.ai.contentsafety.models import (
AnalyzeTextOptions,
AnalyzeTextOutputType,
TextCategory,
)
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
except Exception as e:
raise Exception(
f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
)
self.endpoint = endpoint
self.api_key = api_key
self.text_category = TextCategory
self.analyze_text_options = AnalyzeTextOptions
self.analyze_text_output_type = AnalyzeTextOutputType
self.azure_http_error = HttpResponseError
self.thresholds = self._configure_thresholds(thresholds)
self.client = ContentSafetyClient(
self.endpoint, AzureKeyCredential(self.api_key)
)
def _configure_thresholds(self, thresholds=None):
default_thresholds = {
self.text_category.HATE: 4,
self.text_category.SELF_HARM: 4,
self.text_category.SEXUAL: 4,
self.text_category.VIOLENCE: 4,
}
if thresholds is None:
return default_thresholds
for key, default in default_thresholds.items():
if key not in thresholds:
thresholds[key] = default
return thresholds
def _compute_result(self, response):
result = {}
category_severity = {
item.category: item.severity for item in response.categories_analysis
}
for category in self.text_category:
severity = category_severity.get(category)
if severity is not None:
result[category] = {
"filtered": severity >= self.thresholds[category],
"severity": severity,
}
return result
async def test_violation(self, content: str, source: Optional[str] = None):
verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
# Construct a request
request = self.analyze_text_options(
text=content,
output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
)
# Analyze text
try:
response = await self.client.analyze_text(request)
except self.azure_http_error:
verbose_proxy_logger.debug(
"Error in Azure Content-Safety: %s", traceback.format_exc()
)
verbose_proxy_logger.debug(traceback.format_exc())
raise
result = self._compute_result(response)
verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
for key, value in result.items():
if value["filtered"]:
raise HTTPException(
status_code=400,
detail={
"error": "Violated content safety policy",
"source": source,
"category": key,
"severity": value["severity"],
},
)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
try:
if call_type == "completion" and "messages" in data:
for m in data["messages"]:
if "content" in m and isinstance(m["content"], str):
await self.test_violation(content=m["content"], source="input")
except HTTPException as e:
raise e
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices[0], litellm.utils.Choices
):
await self.test_violation(
content=response.choices[0].message.content or "", source="output"
)
# async def async_post_call_streaming_hook(
# self,
# user_api_key_dict: UserAPIKeyAuth,
# response: str,
# ):
# verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
# await self.test_violation(content=response, source="output")
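
A hedged sketch of wiring this guardrail up: the module path, environment variable names, and the callback-registration call are assumptions (the diff view hides file names, and the proxy can also register hooks through its config). The thresholds argument is keyed by the Azure TextCategory enum, exactly as _configure_thresholds expects.

# Illustrative sketch only; paths and env var names are assumptions, not part of this commit.
import os
import litellm
from azure.ai.contentsafety.models import TextCategory
from litellm.proxy.hooks.azure_content_safety import _PROXY_AzureContentSafety  # assumed path

azure_guard = _PROXY_AzureContentSafety(
    endpoint=os.environ["AZURE_CONTENT_SAFETY_ENDPOINT"],  # e.g. https://<resource>.cognitiveservices.azure.com/
    api_key=os.environ["AZURE_CONTENT_SAFETY_KEY"],
    thresholds={TextCategory.HATE: 2},  # categories left unset keep the default threshold of 4
)

# One common registration pattern: add the instance to litellm's callback list so the
# proxy invokes async_pre_call_hook / async_post_call_success_hook around each request.
litellm.callbacks.append(azure_guard)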

View File

@@ -0,0 +1,149 @@
# What does this do?
## Gets a key's Redis cache entries and stores them in memory for 1 minute.
## This reduces the number of Redis GET requests the proxy makes during high traffic.
### [BETA] This is in beta and might change.
import traceback
from typing import Literal, Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_BatchRedisRequests(CustomLogger):
# Class variables or attributes
in_memory_cache: Optional[InMemoryCache] = None
def __init__(self):
if litellm.cache is not None:
litellm.cache.async_get_cache = (
self.async_get_cache
) # map the litellm 'get_cache' function to our custom function
def print_verbose(
self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
):
if debug_level == "DEBUG":
verbose_proxy_logger.debug(print_statement)
elif debug_level == "INFO":
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose is True:
print(print_statement) # noqa
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
"""
Get the user's API key
Check if a key starting with `litellm:<api_key>:<call_type>:` exists in-memory
If not, fetch the relevant cache entries from Redis
"""
api_key = user_api_key_dict.api_key
cache_key_name = f"litellm:{api_key}:{call_type}"
self.in_memory_cache = cache.in_memory_cache
key_value_dict = {}
in_memory_cache_exists = False
for key in cache.in_memory_cache.cache_dict.keys():
if isinstance(key, str) and key.startswith(cache_key_name):
in_memory_cache_exists = True
if in_memory_cache_exists is False and litellm.cache is not None:
"""
- Check if `litellm.Cache` is redis
- Get the relevant values
"""
if litellm.cache.type is not None and isinstance(
litellm.cache.cache, RedisCache
):
# Initialize an empty list to store the keys
keys = []
self.print_verbose(f"cache_key_name: {cache_key_name}")
# Use the SCAN iterator to fetch keys matching the pattern
keys = await litellm.cache.cache.async_scan_iter(
pattern=cache_key_name, count=100
)
# If you need the truly "last" key based on time or another criterion,
# ensure your key naming or storage strategy allows this determination
# Here you would sort or filter the keys as needed based on your strategy
self.print_verbose(f"redis keys: {keys}")
if len(keys) > 0:
key_value_dict = (
await litellm.cache.cache.async_batch_get_cache(
key_list=keys
)
)
## Add to cache
if len(key_value_dict.items()) > 0:
await cache.in_memory_cache.async_set_cache_pipeline(
cache_list=list(key_value_dict.items()), ttl=60
)
## Set cache namespace if it's a miss
data["metadata"]["redis_namespace"] = cache_key_name
except HTTPException as e:
raise e
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_get_cache(self, *args, **kwargs):
"""
- Check if the cache key is in-memory
- Else:
- add missing cache key from REDIS
- update in-memory cache
- return redis cache request
"""
try: # never block execution
cache_key: Optional[str] = None
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
elif litellm.cache is not None:
cache_key = litellm.cache.get_cache_key(
*args, **kwargs
) # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic
if (
cache_key is not None
and self.in_memory_cache is not None
and litellm.cache is not None
):
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = self.in_memory_cache.get_cache(
cache_key, *args, **kwargs
)
if cached_result is None:
cached_result = await litellm.cache.cache.async_get_cache(
cache_key, *args, **kwargs
)
if cached_result is not None:
await self.in_memory_cache.async_set_cache(
cache_key, cached_result, ttl=60
)
return litellm.cache._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
return None
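
The hook above replaces litellm.cache.async_get_cache with its own version: on a namespace miss it scans Redis once for keys under litellm:<api_key>:<call_type>, bulk-loads the matches into the in-memory cache with a 60-second TTL, and serves subsequent lookups from memory. A standalone sketch of that scan-then-batch-get idea using redis.asyncio directly (key names and the plain dict are illustrative; the real hook goes through litellm's RedisCache and InMemoryCache wrappers):

# Illustrative sketch only; not the proxy's actual cache objects.
# Requires a reachable Redis at localhost:6379.
import asyncio
import redis.asyncio as redis

local_cache: dict = {}  # stands in for the in-memory cache (TTL handling omitted)

async def warm_namespace(r: redis.Redis, namespace: str) -> None:
    # One SCAN pass plus one MGET instead of many individual GETs.
    keys = [k async for k in r.scan_iter(match=f"{namespace}*", count=100)]
    if not keys:
        return
    for key, value in zip(keys, await r.mget(keys)):
        if value is not None:
            local_cache[key] = value

async def main() -> None:
    async with redis.Redis() as r:
        await warm_namespace(r, "litellm:sk-example:completion")
    print(len(local_cache), "entries warmed")

asyncio.run(main())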

View File

@@ -0,0 +1,58 @@
# What does this do?
## Checks if the key is allowed to use the cache controls passed in to the completion() call
from fastapi import HTTPException
from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_CacheControlCheck(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
verbose_proxy_logger.debug("Inside Cache Control Check Pre-Call Hook")
allowed_cache_controls = user_api_key_dict.allowed_cache_controls
if data.get("cache", None) is None:
return
cache_args = data.get("cache", None)
if isinstance(cache_args, dict):
for k, v in cache_args.items():
if (
(allowed_cache_controls is not None)
and (isinstance(allowed_cache_controls, list))
and (
len(allowed_cache_controls) > 0
) # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
and k not in allowed_cache_controls
):
raise HTTPException(
status_code=403,
detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
)
else: # invalid cache
return
except HTTPException as e:
raise e
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
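
Stripped of the proxy plumbing, the check above is list membership over the caller-supplied `cache` dict, with an empty allow-list treated as "no restriction". A minimal sketch of the same rule for a key that may only set no-cache:

# Illustrative sketch of the permission rule (plain Python; the hook raises HTTPException(403)).
allowed_cache_controls = ["no-cache"]            # from the virtual key's settings
request_body = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "hi"}],
    "cache": {"no-cache": True, "ttl": 300},     # caller-supplied cache controls
}

for control in request_body.get("cache", {}):
    if allowed_cache_controls and control not in allowed_cache_controls:
        raise PermissionError(
            f"Not allowed to set {control} as a cache control. Contact admin to change permissions."
        )
# -> raises for "ttl", since only "no-cache" is permitted for this key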

View File

@@ -0,0 +1,306 @@
# What is this?
## Allocates dynamic tpm/rpm quota for a project based on current traffic
## Tracks num active projects per minute
import asyncio
import os
from typing import List, Literal, Optional, Tuple, Union
from fastapi import HTTPException
import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.router import ModelGroupInfo
from litellm.utils import get_utc_datetime
class DynamicRateLimiterCache:
"""
Thin wrapper on DualCache for this file.
Track number of active projects calling a model.
"""
def __init__(self, cache: DualCache) -> None:
self.cache = cache
self.ttl = 60 # 1 min ttl
async def async_get_cache(self, model: str) -> Optional[int]:
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
key_name = "{}:{}".format(current_minute, model)
_response = await self.cache.async_get_cache(key=key_name)
response: Optional[int] = None
if _response is not None:
response = len(_response)
return response
async def async_set_cache_sadd(self, model: str, value: List):
"""
Add value to set.
Parameters:
- model: str, the name of the model group
- value: str, the team id
Returns:
- None
Raises:
- Exception, if unable to connect to cache client (if redis caching enabled)
"""
try:
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
key_name = "{}:{}".format(current_minute, model)
await self.cache.async_set_cache_sadd(
key=key_name, value=value, ttl=self.ttl
)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}".format(
str(e)
)
)
raise e
class _PROXY_DynamicRateLimitHandler(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: DualCache):
self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)
def update_variables(self, llm_router: Router):
self.llm_router = llm_router
async def check_available_usage(
self, model: str, priority: Optional[str] = None
) -> Tuple[
Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
]:
"""
For a given model, get its available tpm and rpm
Params:
- model: str, the name of the model in the router model_list
- priority: Optional[str], the priority for the request.
Returns:
- Tuple[available_tpm, available_rpm, remaining_model_tpm, remaining_model_rpm, active_projects]
- available_tpm: int or null - always 0 or positive.
- available_rpm: int or null - always 0 or positive.
- remaining_model_tpm: int or null. If available tpm is int, then this will be too.
- remaining_model_rpm: int or null. If available rpm is int, then this will be too.
- active_projects: int or null
"""
try:
weight: float = 1
if (
litellm.priority_reservation is None
or priority not in litellm.priority_reservation
):
verbose_proxy_logger.error(
"Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format(
priority, litellm.priority_reservation
)
)
elif priority is not None and litellm.priority_reservation is not None:
if os.getenv("LITELLM_LICENSE", None) is None:
verbose_proxy_logger.error(
"PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
)
else:
weight = litellm.priority_reservation[priority]
active_projects = await self.internal_usage_cache.async_get_cache(
model=model
)
(
current_model_tpm,
current_model_rpm,
) = await self.llm_router.get_model_group_usage(model_group=model)
model_group_info: Optional[
ModelGroupInfo
] = self.llm_router.get_model_group_info(model_group=model)
total_model_tpm: Optional[int] = None
total_model_rpm: Optional[int] = None
if model_group_info is not None:
if model_group_info.tpm is not None:
total_model_tpm = model_group_info.tpm
if model_group_info.rpm is not None:
total_model_rpm = model_group_info.rpm
remaining_model_tpm: Optional[int] = None
if total_model_tpm is not None and current_model_tpm is not None:
remaining_model_tpm = total_model_tpm - current_model_tpm
elif total_model_tpm is not None:
remaining_model_tpm = total_model_tpm
remaining_model_rpm: Optional[int] = None
if total_model_rpm is not None and current_model_rpm is not None:
remaining_model_rpm = total_model_rpm - current_model_rpm
elif total_model_rpm is not None:
remaining_model_rpm = total_model_rpm
available_tpm: Optional[int] = None
if remaining_model_tpm is not None:
if active_projects is not None:
available_tpm = int(remaining_model_tpm * weight / active_projects)
else:
available_tpm = int(remaining_model_tpm * weight)
if available_tpm is not None and available_tpm < 0:
available_tpm = 0
available_rpm: Optional[int] = None
if remaining_model_rpm is not None:
if active_projects is not None:
available_rpm = int(remaining_model_rpm * weight / active_projects)
else:
available_rpm = int(remaining_model_rpm * weight)
if available_rpm is not None and available_rpm < 0:
available_rpm = 0
return (
available_tpm,
available_rpm,
remaining_model_tpm,
remaining_model_rpm,
active_projects,
)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format(
str(e)
)
)
return None, None, None, None, None
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
],
) -> Optional[
Union[Exception, str, dict]
]: # raise an exception if invalid; return a str for the user to receive if rejected; or return a modified dictionary to pass into litellm
"""
- For a model group
- Check if tpm/rpm available
- Raise RateLimitError if no tpm/rpm available
"""
if "model" in data:
key_priority: Optional[str] = user_api_key_dict.metadata.get(
"priority", None
)
(
available_tpm,
available_rpm,
model_tpm,
model_rpm,
active_projects,
) = await self.check_available_usage(
model=data["model"], priority=key_priority
)
### CHECK TPM ###
if available_tpm is not None and available_tpm == 0:
raise HTTPException(
status_code=429,
detail={
"error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
user_api_key_dict.api_key,
available_tpm,
model_tpm,
active_projects,
)
},
)
### CHECK RPM ###
elif available_rpm is not None and available_rpm == 0:
raise HTTPException(
status_code=429,
detail={
"error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
user_api_key_dict.api_key,
available_rpm,
model_rpm,
active_projects,
)
},
)
elif available_rpm is not None or available_tpm is not None:
## UPDATE CACHE WITH ACTIVE PROJECT
asyncio.create_task(
self.internal_usage_cache.async_set_cache_sadd( # this is a set
model=data["model"], # type: ignore
value=[user_api_key_dict.token or "default_key"],
)
)
return None
async def async_post_call_success_hook(
self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
try:
if isinstance(response, ModelResponse):
model_info = self.llm_router.get_model_info(
id=response._hidden_params["model_id"]
)
assert (
model_info is not None
), "Model info for model with id={} is None".format(
response._hidden_params["model_id"]
)
key_priority: Optional[str] = user_api_key_dict.metadata.get(
"priority", None
)
(
available_tpm,
available_rpm,
model_tpm,
model_rpm,
active_projects,
) = await self.check_available_usage(
model=model_info["model_name"], priority=key_priority
)
response._hidden_params[
"additional_headers"
] = { # Add additional response headers - easier debugging
"x-litellm-model_group": model_info["model_name"],
"x-ratelimit-remaining-litellm-project-tokens": available_tpm,
"x-ratelimit-remaining-litellm-project-requests": available_rpm,
"x-ratelimit-remaining-model-tokens": model_tpm,
"x-ratelimit-remaining-model-requests": model_rpm,
"x-ratelimit-current-active-projects": active_projects,
}
return response
return await super().async_post_call_success_hook(
data=data,
user_api_key_dict=user_api_key_dict,
response=response,
)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}".format(
str(e)
)
)
return response
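
check_available_usage reduces to simple arithmetic: remaining model capacity, scaled by the priority weight, split across the projects active in the current minute, floored at zero. A worked sketch with illustrative numbers:

# Worked example of the quota math above (all numbers are illustrative).
total_model_tpm = 100_000      # from the router's ModelGroupInfo
current_model_tpm = 40_000     # usage already recorded for this minute
weight = 0.75                  # e.g. litellm.priority_reservation["prod"]
active_projects = 3            # distinct keys that called this model this minute

remaining_model_tpm = total_model_tpm - current_model_tpm              # 60_000
available_tpm = int(remaining_model_tpm * weight / active_projects)    # 15_000
available_tpm = max(available_tpm, 0)                                  # never negative
print(available_tpm)  # 15000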

View File

@@ -0,0 +1,28 @@
[
{
"name": "Zip code Recognizer",
"supported_language": "en",
"patterns": [
{
"name": "zip code (weak)",
"regex": "(\\b\\d{5}(?:\\-\\d{4})?\\b)",
"score": 0.01
}
],
"context": ["zip", "code"],
"supported_entity": "ZIP"
},
{
"name": "Swiss AHV Number Recognizer",
"supported_language": "en",
"patterns": [
{
"name": "AHV number (strong)",
"regex": "(756\\.\\d{4}\\.\\d{4}\\.\\d{2})|(756\\d{10})",
"score": 0.95
}
],
"context": ["AHV", "social security", "Swiss"],
"supported_entity": "AHV_NUMBER"
}
]
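
These JSON entries follow Presidio's ad-hoc pattern-recognizer schema: a name, regex patterns with confidence scores, context words, and the entity they emit. A hedged sketch of how such an entry could be loaded with the presidio-analyzer package (and its spaCy model); the loader below and the file name it reads are illustrative and not part of this commit:

# Illustrative loader; assumes presidio-analyzer is installed. File name is hypothetical.
import json
from presidio_analyzer import AnalyzerEngine, Pattern, PatternRecognizer

with open("presidio_ad_hoc_recognizers.json") as f:
    entries = json.load(f)

analyzer = AnalyzerEngine()
for entry in entries:
    analyzer.registry.add_recognizer(
        PatternRecognizer(
            supported_entity=entry["supported_entity"],
            supported_language=entry["supported_language"],
            patterns=[Pattern(**p) for p in entry["patterns"]],
            context=entry["context"],
            name=entry["name"],
        )
    )

print(analyzer.analyze(text="My AHV number is 756.1234.5678.90", language="en"))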

View File

@@ -0,0 +1,323 @@
import asyncio
import json
import uuid
from datetime import datetime, timezone
from typing import Any, List, Optional
from fastapi import status
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
GenerateKeyRequest,
GenerateKeyResponse,
KeyRequest,
LiteLLM_AuditLogs,
LiteLLM_VerificationToken,
LitellmTableNames,
ProxyErrorTypes,
ProxyException,
RegenerateKeyRequest,
UpdateKeyRequest,
UserAPIKeyAuth,
WebhookEvent,
)
# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"
class KeyManagementEventHooks:
@staticmethod
async def async_key_generated_hook(
data: GenerateKeyRequest,
response: GenerateKeyResponse,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Hook that runs after a successful /key/generate request
Handles the following:
- Sending Email with Key Details
- Storing Audit Logs for key generation
- Storing Generated Key in DB
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
if data.send_invite_email is True:
await KeyManagementEventHooks._send_key_created_email(
response.model_dump(exclude_none=True)
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = response.model_dump_json(exclude_none=True)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.token_id or "",
action="created",
updated_values=_updated_values,
before_value=None,
)
)
)
# store the generated key in the secret manager
await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
secret_name=data.key_alias or f"virtual-key-{response.token_id}",
secret_token=response.key,
)
@staticmethod
async def async_key_updated_hook(
data: UpdateKeyRequest,
existing_key_row: Any,
response: Any,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/update processing hook
Handles the following:
- Storing Audit Logs for key update
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(data.json(exclude_none=True), default=str)
_before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
)
)
@staticmethod
async def async_key_rotated_hook(
data: Optional[RegenerateKeyRequest],
existing_key_row: Any,
response: GenerateKeyResponse,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
# store the generated key in the secret manager
if data is not None and response.token_id is not None:
initial_secret_name = (
existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
)
await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
current_secret_name=initial_secret_name,
new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
new_secret_value=response.key,
)
@staticmethod
async def async_key_deleted_hook(
data: KeyRequest,
keys_being_deleted: List[LiteLLM_VerificationToken],
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/delete processing hook
Handles the following:
- Storing Audit Logs for key deletion
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
# We do this after the first for loop, since the first for loop is for validation; we only want this inserted after validation passes
if litellm.store_audit_logs is True and data.keys is not None:
# make an audit log for each team deleted
for key in data.keys:
key_row = await prisma_client.get_data( # type: ignore
token=key, table_name="key", query_type="find_unique"
)
if key_row is None:
raise ProxyException(
message=f"Key {key} not found",
type=ProxyErrorTypes.bad_request_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
key_row = key_row.json(exclude_none=True)
_key_row = json.dumps(key_row, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
# delete the keys from the secret manager
await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
keys_being_deleted=keys_being_deleted
)
pass
@staticmethod
async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
"""
Store a virtual key in the secret manager
Args:
secret_name: Name of the virtual key
secret_token: Value of the virtual key (example: sk-1234)
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.base_secret_manager import (
BaseSecretManager,
)
# store the key in the secret manager
if isinstance(litellm.secret_manager_client, BaseSecretManager):
await litellm.secret_manager_client.async_write_secret(
secret_name=KeyManagementEventHooks._get_secret_name(
secret_name
),
secret_value=secret_token,
)
@staticmethod
async def _rotate_virtual_key_in_secret_manager(
current_secret_name: str, new_secret_name: str, new_secret_value: str
):
"""
Update a virtual key in the secret manager
Args:
current_secret_name: Current name of the virtual key in the secret manager
new_secret_name: New name to store the virtual key under
new_secret_value: New value of the virtual key (example: sk-1234)
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.base_secret_manager import (
BaseSecretManager,
)
# rotate the key in the secret manager
if isinstance(litellm.secret_manager_client, BaseSecretManager):
await litellm.secret_manager_client.async_rotate_secret(
current_secret_name=KeyManagementEventHooks._get_secret_name(
current_secret_name
),
new_secret_name=KeyManagementEventHooks._get_secret_name(
new_secret_name
),
new_secret_value=new_secret_value,
)
@staticmethod
def _get_secret_name(secret_name: str) -> str:
if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
"/"
):
return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
else:
return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"
@staticmethod
async def _delete_virtual_keys_from_secret_manager(
keys_being_deleted: List[LiteLLM_VerificationToken],
):
"""
Deletes virtual keys from the secret manager
Args:
keys_being_deleted: List of keys being deleted, this is passed down from the /key/delete operation
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.base_secret_manager import (
BaseSecretManager,
)
if isinstance(litellm.secret_manager_client, BaseSecretManager):
for key in keys_being_deleted:
if key.key_alias is not None:
await litellm.secret_manager_client.async_delete_secret(
secret_name=KeyManagementEventHooks._get_secret_name(
key.key_alias
)
)
else:
verbose_proxy_logger.warning(
f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
)
@staticmethod
async def _send_key_created_email(response: dict):
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
if "email" not in general_settings.get("alerting", []):
raise ValueError(
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
)
event = WebhookEvent(
event="key_created",
event_group="key",
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)
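
The secret-manager helpers above store each virtual key under a configurable prefix; _get_secret_name simply joins the prefix and the key alias, adding a "/" only when the prefix does not already end with one. A tiny sketch of that naming rule, using the LITELLM_PREFIX_STORED_VIRTUAL_KEYS value defined above:

# Illustrative sketch of the naming rule implemented by _get_secret_name.
def build_secret_name(prefix: str, secret_name: str) -> str:
    joined = prefix if prefix.endswith("/") else prefix + "/"
    return f"{joined}{secret_name}"

print(build_secret_name("litellm/", "virtual-key-abc123"))  # litellm/virtual-key-abc123
print(build_secret_name("litellm", "virtual-key-abc123"))   # litellm/virtual-key-abc123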

View File

@@ -0,0 +1,429 @@
# What is this?
## This hook checks for LiteLLM managed files in the request body and replaces them with model-specific file ids
import base64
import json
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from litellm import Router, verbose_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.proxy._types import CallTypes, LiteLLM_ManagedFileTable, UserAPIKeyAuth
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionFileObject,
CreateFileRequest,
OpenAIFileObject,
OpenAIFilesPurpose,
)
from litellm.types.utils import SpecialEnums
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
from litellm.proxy.utils import PrismaClient as _PrismaClient
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
PrismaClient = _PrismaClient
else:
Span = Any
InternalUsageCache = Any
PrismaClient = Any
class BaseFileEndpoints(ABC):
@abstractmethod
async def afile_retrieve(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
) -> OpenAIFileObject:
pass
@abstractmethod
async def afile_list(
self, custom_llm_provider: str, **data: dict
) -> List[OpenAIFileObject]:
pass
@abstractmethod
async def afile_delete(
self, custom_llm_provider: str, file_id: str, **data: dict
) -> OpenAIFileObject:
pass
class _PROXY_LiteLLMManagedFiles(CustomLogger):
# Class variables or attributes
def __init__(
self, internal_usage_cache: InternalUsageCache, prisma_client: PrismaClient
):
self.internal_usage_cache = internal_usage_cache
self.prisma_client = prisma_client
async def store_unified_file_id(
self,
file_id: str,
file_object: OpenAIFileObject,
litellm_parent_otel_span: Optional[Span],
model_mappings: Dict[str, str],
) -> None:
verbose_logger.info(
f"Storing LiteLLM Managed File object with id={file_id} in cache"
)
litellm_managed_file_object = LiteLLM_ManagedFileTable(
unified_file_id=file_id,
file_object=file_object,
model_mappings=model_mappings,
)
await self.internal_usage_cache.async_set_cache(
key=file_id,
value=litellm_managed_file_object.model_dump(),
litellm_parent_otel_span=litellm_parent_otel_span,
)
await self.prisma_client.db.litellm_managedfiletable.create(
data={
"unified_file_id": file_id,
"file_object": file_object.model_dump_json(),
"model_mappings": json.dumps(model_mappings),
}
)
async def get_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> Optional[LiteLLM_ManagedFileTable]:
## CHECK CACHE
result = cast(
Optional[dict],
await self.internal_usage_cache.async_get_cache(
key=file_id,
litellm_parent_otel_span=litellm_parent_otel_span,
),
)
if result:
return LiteLLM_ManagedFileTable(**result)
## CHECK DB
db_object = await self.prisma_client.db.litellm_managedfiletable.find_first(
where={"unified_file_id": file_id}
)
if db_object:
return LiteLLM_ManagedFileTable(**db_object.model_dump())
return None
async def delete_unified_file_id(
self, file_id: str, litellm_parent_otel_span: Optional[Span] = None
) -> OpenAIFileObject:
## get old value
initial_value = await self.prisma_client.db.litellm_managedfiletable.find_first(
where={"unified_file_id": file_id}
)
if initial_value is None:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
## delete old value
await self.internal_usage_cache.async_set_cache(
key=file_id,
value=None,
litellm_parent_otel_span=litellm_parent_otel_span,
)
await self.prisma_client.db.litellm_managedfiletable.delete(
where={"unified_file_id": file_id}
)
return initial_value.file_object
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: Dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
],
) -> Union[Exception, str, Dict, None]:
"""
- Detect litellm_proxy/ file_id
- add dictionary of mappings of litellm_proxy/ file_id -> provider_file_id => {litellm_proxy/file_id: {"model_id": id, "file_id": provider_file_id}}
"""
if call_type == CallTypes.completion.value:
messages = data.get("messages")
if messages:
file_ids = self.get_file_ids_from_messages(messages)
if file_ids:
model_file_id_mapping = await self.get_model_file_id_mapping(
file_ids, user_api_key_dict.parent_otel_span
)
data["model_file_id_mapping"] = model_file_id_mapping
return data
def get_file_ids_from_messages(self, messages: List[AllMessageValues]) -> List[str]:
"""
Gets file ids from messages
"""
file_ids = []
for message in messages:
if message.get("role") == "user":
content = message.get("content")
if content:
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object["file"]
file_id = file_object_file_field.get("file_id")
if file_id:
file_ids.append(file_id)
return file_ids
@staticmethod
def _convert_b64_uid_to_unified_uid(b64_uid: str) -> str:
is_base64_unified_file_id = (
_PROXY_LiteLLMManagedFiles._is_base64_encoded_unified_file_id(b64_uid)
)
if is_base64_unified_file_id:
return is_base64_unified_file_id
else:
return b64_uid
@staticmethod
def _is_base64_encoded_unified_file_id(b64_uid: str) -> Union[str, Literal[False]]:
# Add padding back if needed
padded = b64_uid + "=" * (-len(b64_uid) % 4)
# Decode from base64
try:
decoded = base64.urlsafe_b64decode(padded).decode()
if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
return decoded
else:
return False
except Exception:
return False
def convert_b64_uid_to_unified_uid(self, b64_uid: str) -> str:
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(b64_uid)
if is_base64_unified_file_id:
return is_base64_unified_file_id
else:
return b64_uid
async def get_model_file_id_mapping(
self, file_ids: List[str], litellm_parent_otel_span: Span
) -> dict:
"""
Get model-specific file IDs for a list of proxy file IDs.
Returns a dictionary mapping litellm_proxy/ file_id -> model_id -> model_file_id
1. Get all the litellm_proxy/ file_ids from the messages
2. For each file_id, search for cache keys matching the pattern file_id:*
3. Return a dictionary of mappings of litellm_proxy/ file_id -> model_id -> model_file_id
Example:
{
"litellm_proxy/file_id": {
"model_id": "model_file_id"
}
}
"""
file_id_mapping: Dict[str, Dict[str, str]] = {}
litellm_managed_file_ids = []
for file_id in file_ids:
## CHECK IF FILE ID IS MANAGED BY LITELLM
is_base64_unified_file_id = self._is_base64_encoded_unified_file_id(file_id)
if is_base64_unified_file_id:
litellm_managed_file_ids.append(file_id)
if litellm_managed_file_ids:
# Get all cache keys matching the pattern file_id:*
for file_id in litellm_managed_file_ids:
# Search for any cache key starting with this file_id
unified_file_object = await self.get_unified_file_id(
file_id, litellm_parent_otel_span
)
if unified_file_object:
file_id_mapping[file_id] = unified_file_object.model_mappings
return file_id_mapping
async def create_file_for_each_model(
self,
llm_router: Optional[Router],
_create_file_request: CreateFileRequest,
target_model_names_list: List[str],
litellm_parent_otel_span: Span,
) -> List[OpenAIFileObject]:
if llm_router is None:
raise Exception("LLM Router not initialized. Ensure models added to proxy.")
responses = []
for model in target_model_names_list:
individual_response = await llm_router.acreate_file(
model=model, **_create_file_request
)
responses.append(individual_response)
return responses
async def acreate_file(
self,
create_file_request: CreateFileRequest,
llm_router: Router,
target_model_names_list: List[str],
litellm_parent_otel_span: Span,
) -> OpenAIFileObject:
responses = await self.create_file_for_each_model(
llm_router=llm_router,
_create_file_request=create_file_request,
target_model_names_list=target_model_names_list,
litellm_parent_otel_span=litellm_parent_otel_span,
)
response = await _PROXY_LiteLLMManagedFiles.return_unified_file_id(
file_objects=responses,
create_file_request=create_file_request,
internal_usage_cache=self.internal_usage_cache,
litellm_parent_otel_span=litellm_parent_otel_span,
)
## STORE MODEL MAPPINGS IN DB
model_mappings: Dict[str, str] = {}
for file_object in responses:
model_id = file_object._hidden_params.get("model_id")
if model_id is None:
verbose_logger.warning(
f"Skipping file_object: {file_object} because model_id in hidden_params={file_object._hidden_params} is None"
)
continue
file_id = file_object.id
model_mappings[model_id] = file_id
await self.store_unified_file_id(
file_id=response.id,
file_object=response,
litellm_parent_otel_span=litellm_parent_otel_span,
model_mappings=model_mappings,
)
return response
@staticmethod
async def return_unified_file_id(
file_objects: List[OpenAIFileObject],
create_file_request: CreateFileRequest,
internal_usage_cache: InternalUsageCache,
litellm_parent_otel_span: Span,
) -> OpenAIFileObject:
## GET THE FILE TYPE FROM THE CREATE FILE REQUEST
file_data = extract_file_data(create_file_request["file"])
file_type = file_data["content_type"]
unified_file_id = SpecialEnums.LITELLM_MANAGED_FILE_COMPLETE_STR.value.format(
file_type, str(uuid.uuid4())
)
# Convert to URL-safe base64 and strip padding
base64_unified_file_id = (
base64.urlsafe_b64encode(unified_file_id.encode()).decode().rstrip("=")
)
## CREATE RESPONSE OBJECT
response = OpenAIFileObject(
id=base64_unified_file_id,
object="file",
purpose=create_file_request["purpose"],
created_at=file_objects[0].created_at,
bytes=file_objects[0].bytes,
filename=file_objects[0].filename,
status="uploaded",
)
return response
async def afile_retrieve(
self, file_id: str, litellm_parent_otel_span: Optional[Span]
) -> OpenAIFileObject:
stored_file_object = await self.get_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object.file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
async def afile_list(
self,
purpose: Optional[OpenAIFilesPurpose],
litellm_parent_otel_span: Optional[Span],
**data: Dict,
) -> List[OpenAIFileObject]:
return []
async def afile_delete(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
llm_router: Router,
**data: Dict,
) -> OpenAIFileObject:
file_id = self.convert_b64_uid_to_unified_uid(file_id)
model_file_id_mapping = await self.get_model_file_id_mapping(
[file_id], litellm_parent_otel_span
)
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
for model_id, file_id in specific_model_file_id_mapping.items():
await llm_router.afile_delete(model=model_id, file_id=file_id, **data) # type: ignore
stored_file_object = await self.delete_unified_file_id(
file_id, litellm_parent_otel_span
)
if stored_file_object:
return stored_file_object
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")
async def afile_content(
self,
file_id: str,
litellm_parent_otel_span: Optional[Span],
llm_router: Router,
**data: Dict,
) -> str:
"""
Get the content of a file from the first model that has it
"""
model_file_id_mapping = await self.get_model_file_id_mapping(
[file_id], litellm_parent_otel_span
)
specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
if specific_model_file_id_mapping:
exception_dict = {}
for model_id, file_id in specific_model_file_id_mapping.items():
try:
return await llm_router.afile_content(model=model_id, file_id=file_id, **data) # type: ignore
except Exception as e:
exception_dict[model_id] = str(e)
raise Exception(
f"LiteLLM Managed File object with id={file_id} not found. Checked model id's: {specific_model_file_id_mapping.keys()}. Errors: {exception_dict}"
)
else:
raise Exception(f"LiteLLM Managed File object with id={file_id} not found")

View File

@@ -0,0 +1,49 @@
from fastapi import HTTPException
from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
user_row = await cache.async_get_cache(
cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
)
if user_row is None: # value not yet cached
return
max_budget = user_row["max_budget"]
curr_spend = user_row["spend"]
if max_budget is None:
return
if curr_spend is None:
return
# CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.")
except HTTPException as e:
raise e
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
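
The pre-call check above is a straight comparison of cached spend against the cached max budget, and it lets the request through whenever either value is missing. A minimal sketch of that rule over an illustrative cached user row:

# Illustrative sketch of the budget rule; the row shape mirrors what the hook reads from cache.
user_row = {"user_id": "user-123", "spend": 101.5, "max_budget": 100.0}

max_budget = user_row.get("max_budget")
curr_spend = user_row.get("spend")
if max_budget is not None and curr_spend is not None and curr_spend >= max_budget:
    raise RuntimeError("Max budget limit reached.")  # the hook raises HTTPException(429)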

View File

@@ -0,0 +1,192 @@
import json
from typing import List, Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import Span
from litellm.proxy._types import UserAPIKeyAuth
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
BudgetConfig,
GenericBudgetConfigType,
StandardLoggingPayload,
)
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
"""
Handles budgets for model + virtual key
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
"""
def __init__(self, dual_cache: DualCache):
self.dual_cache = dual_cache
self.redis_increment_operation_queue = []
async def is_key_within_model_budget(
self,
user_api_key_dict: UserAPIKeyAuth,
model: str,
) -> bool:
"""
Check if the user_api_key_dict is within the model budget
Raises:
BudgetExceededError: If the user_api_key_dict has exceeded the model budget
"""
_model_max_budget = user_api_key_dict.model_max_budget
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in _model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
verbose_proxy_logger.debug(
"internal_model_max_budget %s",
json.dumps(internal_model_max_budget, indent=4, default=str),
)
# check if current model is in internal_model_max_budget
_current_model_budget_info = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if _current_model_budget_info is None:
verbose_proxy_logger.debug(
f"Model {model} not found in internal_model_max_budget"
)
return True
# check if current model is within budget
if (
_current_model_budget_info.max_budget
and _current_model_budget_info.max_budget > 0
):
_current_spend = await self._get_virtual_key_spend_for_model(
user_api_key_hash=user_api_key_dict.token,
model=model,
key_budget_config=_current_model_budget_info,
)
if (
_current_spend is not None
and _current_model_budget_info.max_budget is not None
and _current_spend > _current_model_budget_info.max_budget
):
raise litellm.BudgetExceededError(
message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
current_cost=_current_spend,
max_budget=_current_model_budget_info.max_budget,
)
return True
async def _get_virtual_key_spend_for_model(
self,
user_api_key_hash: Optional[str],
model: str,
key_budget_config: BudgetConfig,
) -> Optional[float]:
"""
Get the current spend for a virtual key for a model
Lookup model in this order:
1. model: directly look up `model`
2. If 1 does not exist, check if it was passed as {custom_llm_provider}/model
"""
# 1. model: directly look up `model`
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key,
)
if _current_spend is None:
# 2. If 1 does not exist, check if it was passed as {custom_llm_provider}/model
# if "/" in model, remove first part before "/" - eg. openai/o1-preview -> o1-preview
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key,
)
return _current_spend
def _get_request_model_budget_config(
self, model: str, internal_model_max_budget: GenericBudgetConfigType
) -> Optional[BudgetConfig]:
"""
Get the budget config for the request model
1. Check if `model` is in `internal_model_max_budget`
2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
"""
return internal_model_max_budget.get(
model, None
) or internal_model_max_budget.get(
self._get_model_without_custom_llm_provider(model), None
)
def _get_model_without_custom_llm_provider(self, model: str) -> str:
if "/" in model:
return model.split("/")[-1]
return model
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None, # type: ignore
) -> List[dict]:
return healthy_deployments
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Track spend for virtual key + model in DualCache
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
"""
verbose_proxy_logger.debug("in RouterBudgetLimiting.async_log_success_event")
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_payload is None:
raise ValueError("standard_logging_payload is required")
_litellm_params: dict = kwargs.get("litellm_params", {}) or {}
_metadata: dict = _litellm_params.get("metadata", {}) or {}
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
"user_api_key_model_max_budget", None
)
if (
user_api_key_model_max_budget is None
or len(user_api_key_model_max_budget) == 0
):
verbose_proxy_logger.debug(
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget is None or empty. `user_api_key_model_max_budget`=%s",
user_api_key_model_max_budget,
)
return
response_cost: float = standard_logging_payload.get("response_cost", 0)
model = standard_logging_payload.get("model")
virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
if virtual_key is not None:
budget_config = BudgetConfig(time_period="1d", budget_limit=0.1)
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.budget_duration}"
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
await self._increment_spend_for_key(
budget_config=budget_config,
spend_key=virtual_spend_key,
start_time_key=virtual_start_time_key,
response_cost=response_cost,
)
verbose_proxy_logger.debug(
"current state of in memory cache %s",
json.dumps(
self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
),
)
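
Spend is tracked per (virtual key hash, model, budget window) under a namespaced cache key, and lookups fall back to the model name without its provider prefix. A small sketch of that key scheme and the two-step lookup (hash and spend values are illustrative):

# Illustrative sketch of the spend-cache key scheme and the provider-prefix fallback.
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"

def spend_key(key_hash: str, model: str, budget_duration: str) -> str:
    return f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{key_hash}:{model}:{budget_duration}"

def without_provider(model: str) -> str:
    return model.split("/")[-1] if "/" in model else model

cache = {spend_key("examplehash", "o1-preview", "1d"): 42.0}  # illustrative stored spend

model = "openai/o1-preview"
spend = cache.get(spend_key("examplehash", model, "1d"))
if spend is None:                                             # fall back to the bare model name
    spend = cache.get(spend_key("examplehash", without_provider(model), "1d"))
print(spend)  # 42.0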

View File

@@ -0,0 +1,868 @@
import asyncio
import sys
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, TypedDict, Union
from fastapi import HTTPException
from pydantic import BaseModel
import litellm
from litellm import DualCache, ModelResponse
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
else:
Span = Any
InternalUsageCache = Any
class CacheObject(TypedDict):
current_global_requests: Optional[dict]
request_count_api_key: Optional[dict]
request_count_api_key_model: Optional[dict]
request_count_user_id: Optional[dict]
request_count_team_id: Optional[dict]
request_count_end_user_id: Optional[dict]
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
self.internal_usage_cache = internal_usage_cache
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except Exception:
pass
async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
max_parallel_requests: int,
tpm_limit: int,
rpm_limit: int,
current: Optional[dict],
request_count_api_key: str,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
values_to_update_in_cache: List[Tuple[Any, Any]],
) -> dict:
verbose_proxy_logger.info(
f"Current Usage of {rate_limit_type} in this minute: {current}"
)
if current is None:
if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
# base case
raise self.raise_rate_limit_error(
additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
)
new_val = {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 1,
}
values_to_update_in_cache.append((request_count_api_key, new_val))
elif (
int(current["current_requests"]) < max_parallel_requests
and current["current_tpm"] < tpm_limit
and current["current_rpm"] < rpm_limit
):
# Increase count for this token
new_val = {
"current_requests": current["current_requests"] + 1,
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"] + 1,
}
values_to_update_in_cache.append((request_count_api_key, new_val))
else:
raise HTTPException(
status_code=429,
detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
headers={"retry-after": str(self.time_to_next_minute())},
)
await self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
local_only=True,
)
return new_val
def time_to_next_minute(self) -> float:
# Get the current time
now = datetime.now()
# Calculate the next minute
next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
# Calculate the difference in seconds
seconds_to_next_minute = (next_minute - now).total_seconds()
return seconds_to_next_minute
def raise_rate_limit_error(
self, additional_details: Optional[str] = None
) -> HTTPException:
"""
Raise an HTTPException with a 429 status code and a retry-after header
"""
error_message = "Max parallel request limit reached"
if additional_details is not None:
error_message = error_message + " " + additional_details
raise HTTPException(
status_code=429,
detail=f"Max parallel request limit reached {additional_details}",
headers={"retry-after": str(self.time_to_next_minute())},
)
async def get_all_cache_objects(
self,
current_global_requests: Optional[str],
request_count_api_key: Optional[str],
request_count_api_key_model: Optional[str],
request_count_user_id: Optional[str],
request_count_team_id: Optional[str],
request_count_end_user_id: Optional[str],
parent_otel_span: Optional[Span] = None,
) -> CacheObject:
keys = [
current_global_requests,
request_count_api_key,
request_count_api_key_model,
request_count_user_id,
request_count_team_id,
request_count_end_user_id,
]
results = await self.internal_usage_cache.async_batch_get_cache(
keys=keys,
parent_otel_span=parent_otel_span,
)
if results is None:
return CacheObject(
current_global_requests=None,
request_count_api_key=None,
request_count_api_key_model=None,
request_count_user_id=None,
request_count_team_id=None,
request_count_end_user_id=None,
)
return CacheObject(
current_global_requests=results[0],
request_count_api_key=results[1],
request_count_api_key_model=results[2],
request_count_user_id=results[3],
request_count_team_id=results[4],
request_count_end_user_id=results[5],
)
async def async_pre_call_hook( # noqa: PLR0915
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests
if max_parallel_requests is None:
max_parallel_requests = sys.maxsize
if data is None:
data = {}
global_max_parallel_requests = data.get("metadata", {}).get(
"global_max_parallel_requests", None
)
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
if tpm_limit is None:
tpm_limit = sys.maxsize
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
if rpm_limit is None:
rpm_limit = sys.maxsize
values_to_update_in_cache: List[
Tuple[Any, Any]
] = (
[]
) # values that need to get updated in cache, will run a batch_set_cache after this function
# ------------
# Setup values
# ------------
new_val: Optional[dict] = None
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = await self.internal_usage_cache.async_get_cache(
key=_key,
local_only=True,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
# check if below limit
if current_global_requests is None:
current_global_requests = 1
# if above -> raise error
if current_global_requests >= global_max_parallel_requests:
return self.raise_rate_limit_error(
additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}"
)
# if below -> increment
else:
await self.internal_usage_cache.async_increment_cache(
key=_key,
value=1,
local_only=True,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
_model = data.get("model", None)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
cache_objects: CacheObject = await self.get_all_cache_objects(
current_global_requests=(
"global_max_parallel_requests"
if global_max_parallel_requests is not None
else None
),
request_count_api_key=(
f"{api_key}::{precise_minute}::request_count"
if api_key is not None
else None
),
request_count_api_key_model=(
f"{api_key}::{_model}::{precise_minute}::request_count"
if api_key is not None and _model is not None
else None
),
request_count_user_id=(
f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
if user_api_key_dict.user_id is not None
else None
),
request_count_team_id=(
f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
if user_api_key_dict.team_id is not None
else None
),
request_count_end_user_id=(
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
if user_api_key_dict.end_user_id is not None
else None
),
parent_otel_span=user_api_key_dict.parent_otel_span,
)
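# One batched cache read fetches all six per-minute counters at once
# (global, api key, api key + model, user, team, end-user)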
if api_key is not None:
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=max_parallel_requests,
current=cache_objects["request_count_api_key"],
request_count_api_key=request_count_api_key,
tpm_limit=tpm_limit,
rpm_limit=rpm_limit,
rate_limit_type="key",
values_to_update_in_cache=values_to_update_in_cache,
)
# Check if request under RPM/TPM per model for a given API Key
if (
get_key_model_tpm_limit(user_api_key_dict) is not None
or get_key_model_rpm_limit(user_api_key_dict) is not None
):
_model = data.get("model", None)
request_count_api_key = (
f"{api_key}::{_model}::{precise_minute}::request_count"
)
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
tpm_limit_for_model = None
rpm_limit_for_model = None
if _model is not None:
if _tpm_limit_for_key_model:
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
if _rpm_limit_for_key_model:
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
new_val = await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a model
current=cache_objects["request_count_api_key_model"],
request_count_api_key=request_count_api_key,
tpm_limit=tpm_limit_for_model or sys.maxsize,
rpm_limit=rpm_limit_for_model or sys.maxsize,
rate_limit_type="model_per_key",
values_to_update_in_cache=values_to_update_in_cache,
)
_remaining_tokens = None
_remaining_requests = None
# Add remaining tokens, requests to metadata
if new_val:
if tpm_limit_for_model is not None:
_remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
if rpm_limit_for_model is not None:
_remaining_requests = rpm_limit_for_model - new_val["current_rpm"]
_remaining_limits_data = {
f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
f"litellm-key-remaining-requests-{_model}": _remaining_requests,
}
if "metadata" not in data:
data["metadata"] = {}
data["metadata"].update(_remaining_limits_data)
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
if user_id is not None:
user_tpm_limit = user_api_key_dict.user_tpm_limit
user_rpm_limit = user_api_key_dict.user_rpm_limit
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
user_rpm_limit = sys.maxsize
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
current=cache_objects["request_count_user_id"],
request_count_api_key=request_count_api_key,
tpm_limit=user_tpm_limit,
rpm_limit=user_rpm_limit,
rate_limit_type="user",
values_to_update_in_cache=values_to_update_in_cache,
)
# TEAM RATE LIMITS
## get team tpm/rpm limits
team_id = user_api_key_dict.team_id
if team_id is not None:
team_tpm_limit = user_api_key_dict.team_tpm_limit
team_rpm_limit = user_api_key_dict.team_rpm_limit
if team_tpm_limit is None:
team_tpm_limit = sys.maxsize
if team_rpm_limit is None:
team_rpm_limit = sys.maxsize
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
current=cache_objects["request_count_team_id"],
request_count_api_key=request_count_api_key,
tpm_limit=team_tpm_limit,
rpm_limit=team_rpm_limit,
rate_limit_type="team",
values_to_update_in_cache=values_to_update_in_cache,
)
# End-User Rate Limits
# Only enforce if user passed `user` to /chat, /completions, /embeddings
if user_api_key_dict.end_user_id:
end_user_tpm_limit = getattr(
user_api_key_dict, "end_user_tpm_limit", sys.maxsize
)
end_user_rpm_limit = getattr(
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
)
if end_user_tpm_limit is None:
end_user_tpm_limit = sys.maxsize
if end_user_rpm_limit is None:
end_user_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = (
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
)
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
request_count_api_key=request_count_api_key,
current=cache_objects["request_count_end_user_id"],
tpm_limit=end_user_tpm_limit,
rpm_limit=end_user_rpm_limit,
rate_limit_type="customer",
values_to_update_in_cache=values_to_update_in_cache,
)
asyncio.create_task(
self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # don't block execution for cache updates
)
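# Note: the batched cache write above is fire-and-forget (wrapped in asyncio.create_task,
# not awaited), so the request is not blocked on cache I/O; the 60s TTL matches the
# per-minute rate-limit window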
return
async def async_log_success_event( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
):
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
kwargs=kwargs
)
try:
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_team_id", None
)
user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
"user_api_key_model_max_budget", None
)
user_api_key_end_user_id = kwargs.get("user")
user_api_key_metadata = (
kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {})
or {}
)
# ------------
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
# decrement
await self.internal_usage_cache.async_increment_cache(
key=_key,
value=-1,
local_only=True,
litellm_parent_otel_span=litellm_parent_otel_span,
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
# ------------
# Update usage - API Key
# ------------
values_to_update_in_cache = []
if user_api_key is not None:
request_count_api_key = (
f"{user_api_key}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
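# On success: release one in-flight request slot and add this call's tokens to the
# minute's TPM counter; current_rpm is carried over unchanged (request counting is
# assumed to happen at admission time, in the pre-call path)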
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - model group + API Key
# ------------
model_group = get_model_group_from_litellm_kwargs(kwargs)
if (
user_api_key is not None
and model_group is not None
and (
"model_rpm_limit" in user_api_key_metadata
or "model_tpm_limit" in user_api_key_metadata
or user_api_key_model_max_budget is not None
)
):
request_count_api_key = (
f"{user_api_key}::{model_group}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - User
# ------------
if user_api_key_user_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
request_count_api_key = (
f"{user_api_key_user_id}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - Team
# ------------
if user_api_key_team_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
request_count_api_key = (
f"{user_api_key_team_id}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - End User
# ------------
if user_api_key_end_user_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
request_count_api_key = (
f"{user_api_key_end_user_id}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
await self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=litellm_parent_otel_span,
)
except Exception as e:
self.print_verbose(e) # noqa
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose("Inside Max Parallel Request Failure Hook")
litellm_parent_otel_span: Union[
Span, None
] = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
_metadata = kwargs["litellm_params"].get("metadata", {}) or {}
global_max_parallel_requests = _metadata.get(
"global_max_parallel_requests", None
)
user_api_key = _metadata.get("user_api_key", None)
self.print_verbose(f"user_api_key: {user_api_key}")
if user_api_key is None:
return
## decrement call count if call failed
if CommonProxyErrors.max_parallel_request_limit_reached.value in str(
kwargs["exception"]
):
pass # ignore failed calls due to max limit being reached
else:
# ------------
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
(
await self.internal_usage_cache.async_get_cache(
key=_key,
local_only=True,
litellm_parent_otel_span=litellm_parent_otel_span,
)
)
# decrement
await self.internal_usage_cache.async_increment_cache(
key=_key,
value=-1,
local_only=True,
litellm_parent_otel_span=litellm_parent_otel_span,
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = (
f"{user_api_key}::{precise_minute}::request_count"
)
# ------------
# Update usage
# ------------
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"],
}
self.print_verbose(f"updated_value in failure call: {new_val}")
await self.internal_usage_cache.async_set_cache(
request_count_api_key,
new_val,
ttl=60,
litellm_parent_otel_span=litellm_parent_otel_span,
) # save in cache for up to 1 min.
except Exception as e:
verbose_proxy_logger.exception(
"Inside Parallel Request Limiter: An exception occurred - {}".format(
str(e)
)
)
async def get_internal_user_object(
self,
user_id: str,
user_api_key_dict: UserAPIKeyAuth,
) -> Optional[dict]:
"""
Helper to get the 'Internal User Object'
It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`
We need this because the UserAPIKeyAuth object does not contain the rpm/tpm limits for a user, and reading the UserTable directly on every request would add a performance cost.
"""
from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.auth_checks import get_user_object
from litellm.proxy.proxy_server import prisma_client
try:
_user_id_rate_limits = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=self.internal_usage_cache.dual_cache,
user_id_upsert=False,
parent_otel_span=user_api_key_dict.parent_otel_span,
proxy_logging_obj=None,
)
if _user_id_rate_limits is None:
return None
return _user_id_rate_limits.model_dump()
except Exception as e:
verbose_proxy_logger.debug(
"Parallel Request Limiter: Error getting user object", str(e)
)
return None
async def async_post_call_success_hook(
self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
"""
Retrieve the key's remaining rate limits.
"""
api_key = user_api_key_dict.api_key
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
current: Optional[
CurrentItemRateLimit
] = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
key_remaining_rpm_limit: Optional[int] = None
key_rpm_limit: Optional[int] = None
key_remaining_tpm_limit: Optional[int] = None
key_tpm_limit: Optional[int] = None
if current is not None:
if user_api_key_dict.rpm_limit is not None:
key_remaining_rpm_limit = (
user_api_key_dict.rpm_limit - current["current_rpm"]
)
key_rpm_limit = user_api_key_dict.rpm_limit
if user_api_key_dict.tpm_limit is not None:
key_remaining_tpm_limit = (
user_api_key_dict.tpm_limit - current["current_tpm"]
)
key_tpm_limit = user_api_key_dict.tpm_limit
if hasattr(response, "_hidden_params"):
_hidden_params = getattr(response, "_hidden_params")
else:
_hidden_params = None
if _hidden_params is not None and (
isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict)
):
if isinstance(_hidden_params, BaseModel):
_hidden_params = _hidden_params.model_dump()
_additional_headers = _hidden_params.get("additional_headers", {}) or {}
if key_remaining_rpm_limit is not None:
_additional_headers[
"x-ratelimit-remaining-requests"
] = key_remaining_rpm_limit
if key_rpm_limit is not None:
_additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
if key_remaining_tpm_limit is not None:
_additional_headers[
"x-ratelimit-remaining-tokens"
] = key_remaining_tpm_limit
if key_tpm_limit is not None:
_additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit
setattr(
response,
"_hidden_params",
{**_hidden_params, "additional_headers": _additional_headers},
)
return await super().async_post_call_success_hook(
data, user_api_key_dict, response
)
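# Illustrative only (assumed client-side usage, not part of this file, and assuming the
# proxy forwards additional_headers onto the HTTP response): a client can back off, e.g.
#
#   remaining = int(resp.headers.get("x-ratelimit-remaining-requests", "1"))
#   if remaining <= 0:
#       time.sleep(60)  # wait for the next per-minute window before retrying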

View File

@@ -0,0 +1,282 @@
# +------------------------------------+
#
# Prompt Injection Detection
#
# +------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
## Reject a call if it contains a prompt injection attack.
from difflib import SequenceMatcher
from typing import List, Literal, Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.factory import (
prompt_injection_detection_default_pt,
)
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
from litellm.router import Router
from litellm.utils import get_formatted_prompt
class _OPTIONAL_PromptInjectionDetection(CustomLogger):
# Class variables or attributes
def __init__(
self,
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
):
self.prompt_injection_params = prompt_injection_params
self.llm_router: Optional[Router] = None
self.verbs = [
"Ignore",
"Disregard",
"Skip",
"Forget",
"Neglect",
"Overlook",
"Omit",
"Bypass",
"Pay no attention to",
"Do not follow",
"Do not obey",
]
self.adjectives = [
"",
"prior",
"previous",
"preceding",
"above",
"foregoing",
"earlier",
"initial",
]
self.prepositions = [
"",
"and start over",
"and start anew",
"and begin afresh",
"and start from scratch",
]
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
if level == "INFO":
verbose_proxy_logger.info(print_statement)
elif level == "DEBUG":
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose is True:
print(print_statement) # noqa
def update_environment(self, router: Optional[Router] = None):
self.llm_router = router
if (
self.prompt_injection_params is not None
and self.prompt_injection_params.llm_api_check is True
):
if self.llm_router is None:
raise Exception(
"PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
)
self.print_verbose(
f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
)
if (
self.prompt_injection_params.llm_api_name is None
or self.prompt_injection_params.llm_api_name
not in self.llm_router.model_names
):
raise Exception(
"PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
)
def generate_injection_keywords(self) -> List[str]:
combinations = []
for verb in self.verbs:
for adj in self.adjectives:
for prep in self.prepositions:
phrase = " ".join(filter(None, [verb, adj, prep])).strip()
if (
len(phrase.split()) > 2
): # additional check to ensure more than 2 words
combinations.append(phrase.lower())
return combinations
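# Illustrative output (derived from the word lists above): phrases such as
# "ignore previous and start over", "disregard prior and begin afresh",
# "do not follow earlier and start from scratch"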
def check_user_input_similarity(
self,
user_input: str,
similarity_threshold: float = DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD,
) -> bool:
user_input_lower = user_input.lower()
keywords = self.generate_injection_keywords()
for keyword in keywords:
# Calculate the length of the keyword to extract substrings of the same length from user input
keyword_length = len(keyword)
for i in range(len(user_input_lower) - keyword_length + 1):
# Extract a substring of the same length as the keyword
substring = user_input_lower[i : i + keyword_length]
# Calculate similarity
match_ratio = SequenceMatcher(None, substring, keyword).ratio()
if match_ratio > similarity_threshold:
self.print_verbose(
print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
level="INFO",
)
return True # Found a highly similar substring
return False # No substring crossed the threshold
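# Illustrative: a sliding window of the same length as each keyword is compared with
# difflib's SequenceMatcher; an exact hit like "ignore previous and start over" inside
# the user input scores 1.0, which exceeds any threshold below 1.0 and is rejected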
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
try:
"""
- format the incoming request into a prompt string
- run the configured heuristics / similarity checks for prompt injection
- reject the request if an attack is detected
"""
self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
try:
assert call_type in [
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
]
except Exception:
self.print_verbose(
f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
)
return data
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
is_prompt_attack = False
if self.prompt_injection_params is not None:
# 1. check if heuristics check turned on
if self.prompt_injection_params.heuristics_check is True:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
if self.prompt_injection_params.vector_db_check is True:
pass
else:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
return data
except HTTPException as e:
if (
e.status_code == 400
and isinstance(e.detail, dict)
and "error" in e.detail # type: ignore
and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response
):
return e.detail.get("error")
raise e
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
async def async_moderation_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
) -> Optional[bool]:
self.print_verbose(
f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
)
if self.prompt_injection_params is None:
return None
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
is_prompt_attack = False
prompt_injection_system_prompt = getattr(
self.prompt_injection_params,
"llm_api_system_prompt",
prompt_injection_detection_default_pt(),
)
# 3. check if llm api check turned on
if (
self.prompt_injection_params.llm_api_check is True
and self.prompt_injection_params.llm_api_name is not None
and self.llm_router is not None
):
# make a call to the llm api
response = await self.llm_router.acompletion(
model=self.prompt_injection_params.llm_api_name,
messages=[
{
"role": "system",
"content": prompt_injection_system_prompt,
},
{"role": "user", "content": formatted_prompt},
],
)
self.print_verbose(f"Received LLM Moderation response: {response}")
self.print_verbose(
f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
)
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices[0], litellm.Choices
):
if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore
is_prompt_attack = True
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
return is_prompt_attack

View File

@@ -0,0 +1,257 @@
import asyncio
import traceback
from datetime import datetime
from typing import Any, Optional, Union, cast
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingUserAPIKeyMetadata,
)
from litellm.utils import get_end_user_id_for_cost_tracking
class _ProxyDBLogger(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
await self._PROXY_track_cost_callback(
kwargs, response_obj, start_time, end_time
)
async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
):
request_route = user_api_key_dict.request_route
if _ProxyDBLogger._should_track_errors_in_db() is False:
return
elif request_route is not None and not RouteChecks.is_llm_api_route(
route=request_route
):
return
from litellm.proxy.proxy_server import proxy_logging_obj
_metadata = dict(
StandardLoggingUserAPIKeyMetadata(
user_api_key_hash=user_api_key_dict.api_key,
user_api_key_alias=user_api_key_dict.key_alias,
user_api_key_user_email=user_api_key_dict.user_email,
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_id=user_api_key_dict.team_id,
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
)
)
_metadata["user_api_key"] = user_api_key_dict.api_key
_metadata["status"] = "failure"
_metadata[
"error_information"
] = StandardLoggingPayloadSetup.get_error_information(
original_exception=original_exception,
)
existing_metadata: dict = request_data.get("metadata", None) or {}
existing_metadata.update(_metadata)
if "litellm_params" not in request_data:
request_data["litellm_params"] = {}
request_data["litellm_params"]["proxy_server_request"] = (
request_data.get("proxy_server_request") or {}
)
request_data["litellm_params"]["metadata"] = existing_metadata
await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key_dict.api_key,
response_cost=0.0,
user_id=user_api_key_dict.user_id,
end_user_id=user_api_key_dict.end_user_id,
team_id=user_api_key_dict.team_id,
kwargs=request_data,
completion_response=original_exception,
start_time=datetime.now(),
end_time=datetime.now(),
org_id=user_api_key_dict.org_id,
)
@log_db_metrics
async def _PROXY_track_cost_callback(
self,
kwargs, # kwargs to completion
completion_response: Optional[
Union[litellm.ModelResponse, Any]
], # response from completion
start_time=None,
end_time=None, # start/end time for completion
):
from litellm.proxy.proxy_server import (
prisma_client,
proxy_logging_obj,
update_cache,
)
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
try:
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {}
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
response_cost = (
sl_object.get("response_cost", None)
if sl_object is not None
else kwargs.get("response_cost", None)
)
if response_cost is not None:
user_api_key = metadata.get("user_api_key", None)
if kwargs.get("cache_hit", False) is True:
response_cost = 0.0
verbose_proxy_logger.info(
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
)
verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
)
if _should_track_cost_callback(
user_api_key=user_api_key,
user_id=user_id,
team_id=team_id,
end_user_id=end_user_id,
):
## UPDATE DATABASE
await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
end_user_id=end_user_id,
team_id=team_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
org_id=org_id,
)
# update cache
asyncio.create_task(
update_cache(
token=user_api_key,
user_id=user_id,
end_user_id=end_user_id,
response_cost=response_cost,
team_id=team_id,
parent_otel_span=parent_otel_span,
)
)
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
token=user_api_key,
key_alias=key_alias,
end_user_id=end_user_id,
response_cost=response_cost,
max_budget=end_user_max_budget,
)
else:
raise Exception(
"User API key and team id and user id missing from custom callback."
)
else:
if kwargs["stream"] is not True or (
kwargs["stream"] is True and "complete_streaming_response" in kwargs
):
if sl_object is not None:
cost_tracking_failure_debug_info: Union[dict, str] = (
sl_object["response_cost_failure_debug_info"] # type: ignore
or "response_cost_failure_debug_info is None in standard_logging_object"
)
else:
cost_tracking_failure_debug_info = (
"standard_logging_object not found"
)
model = kwargs.get("model")
raise Exception(
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
except Exception as e:
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
model = kwargs.get("model", "")
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
litellm_metadata = kwargs.get("litellm_params", {}).get(
"litellm_metadata", {}
)
old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
call_type = kwargs.get("call_type", "")
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
asyncio.create_task(
proxy_logging_obj.failed_tracking_alert(
error_message=error_msg,
failing_model=model,
)
)
verbose_proxy_logger.exception(
"Error in tracking cost callback - %s", str(e)
)
@staticmethod
def _should_track_errors_in_db():
"""
Returns True if errors should be tracked in the database
By default, errors are tracked in the database
If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
"""
from litellm.proxy.proxy_server import general_settings
if general_settings.get("disable_error_logs") is True:
return False
return True
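# Illustrative proxy config snippet (assumed YAML layout) for turning this off:
#
#   general_settings:
#     disable_error_logs: true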
def _should_track_cost_callback(
user_api_key: Optional[str],
user_id: Optional[str],
team_id: Optional[str],
end_user_id: Optional[str],
) -> bool:
"""
Determine if the cost callback should be tracked based on the kwargs
"""
# don't run track cost callback if user opted into disabling spend
if ProxyUpdateSpend.disable_spend_updates() is True:
return False
if (
user_api_key is not None
or user_id is not None
or team_id is not None
or end_user_id is not None
):
return True
return False
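# Illustrative (assuming spend updates are not globally disabled): spend is only written
# when the call can be attributed to something, e.g.
#   _should_track_cost_callback(user_api_key="sk-123", user_id=None, team_id=None, end_user_id=None) -> True
#   _should_track_cost_callback(user_api_key=None, user_id=None, team_id=None, end_user_id=None)     -> False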