structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions


@@ -0,0 +1,169 @@
def show_missing_vars_in_env():
from fastapi.responses import HTMLResponse
from litellm.proxy.proxy_server import master_key, prisma_client
if prisma_client is None and master_key is None:
return HTMLResponse(
content=missing_keys_form(
missing_key_names="DATABASE_URL, LITELLM_MASTER_KEY"
),
status_code=200,
)
if prisma_client is None:
return HTMLResponse(
content=missing_keys_form(missing_key_names="DATABASE_URL"), status_code=200
)
if master_key is None:
return HTMLResponse(
content=missing_keys_form(missing_key_names="LITELLM_MASTER_KEY"),
status_code=200,
)
return None
def missing_keys_form(missing_key_names: str):
missing_keys_html_form = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Environment Setup Instructions</title>
</head>
<body>
<div class="container">
<h1>Environment Setup Instructions</h1>
<p>Please add the following variables to your environment variables:</p>
<pre>
<span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># Your master key for the proxy server. Can use this to send /chat/completion requests etc</span>
<span class="env-var">LITELLM_SALT_KEY="sk-XXXXXXXX"</span> <span class="comment"># Can NOT CHANGE THIS ONCE SET - It is used to encrypt/decrypt credentials stored in DB. If value of 'LITELLM_SALT_KEY' changes your models cannot be retrieved from DB</span>
<span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span>
<span class="comment">## OPTIONAL ##</span>
<span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span>
<span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span>
</pre>
<h1>Missing Environment Variables</h1>
<p>{missing_keys}</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
return missing_keys_html_form.format(missing_keys=missing_key_names)
def admin_ui_disabled():
from fastapi.responses import HTMLResponse
ui_disabled_html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Admin UI Disabled</title>
</head>
<body>
<div class="container">
<h1>Admin UI is Disabled</h1>
<p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
<pre>
<span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span>
</pre>
<p>After making this change, restart the application for it to take effect.</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
return HTMLResponse(
content=ui_disabled_html,
status_code=200,
)
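
A minimal sketch (not part of this commit) of how a route might use show_missing_vars_in_env() as a guard: if required variables are missing, it returns the setup-instructions page; otherwise it falls through to a normal response. The route path and fallback body are hypothetical.

# Hypothetical wiring - assumes a FastAPI app is already constructed.
from fastapi import FastAPI
from fastapi.responses import HTMLResponse

app = FastAPI()

@app.get("/setup-check", include_in_schema=False)  # illustrative route, not in this commit
async def setup_check():
    missing_vars_page = show_missing_vars_in_env()
    if missing_vars_page is not None:
        return missing_vars_page  # render the environment-setup instructions
    return HTMLResponse(content="<h1>All required environment variables are set.</h1>")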


@@ -0,0 +1,310 @@
from typing import Any, Dict, List, Optional
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.types_utils.utils import get_instance_fn
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
def initialize_callbacks_on_proxy( # noqa: PLR0915
value: Any,
premium_user: bool,
config_file_path: str,
litellm_settings: dict,
callback_specific_params: dict = {},
):
from litellm.proxy.proxy_server import prisma_client
verbose_proxy_logger.debug(
f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
)
if isinstance(value, list):
imported_list: List[Any] = []
for callback in value: # ["presidio", <my-custom-callback>]
if (
isinstance(callback, str)
and callback in litellm._known_custom_logger_compatible_callbacks
):
imported_list.append(callback)
elif isinstance(callback, str) and callback == "presidio":
from litellm.proxy.guardrails.guardrail_hooks.presidio import (
_OPTIONAL_PresidioPIIMasking,
)
presidio_logging_only: Optional[bool] = litellm_settings.get(
"presidio_logging_only", None
)
if presidio_logging_only is not None:
presidio_logging_only = bool(
presidio_logging_only
) # validate boolean given
_presidio_params = {}
if "presidio" in callback_specific_params and isinstance(
callback_specific_params["presidio"], dict
):
_presidio_params = callback_specific_params["presidio"]
params: Dict[str, Any] = {
"logging_only": presidio_logging_only,
**_presidio_params,
}
pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
imported_list.append(pii_masking_object)
elif isinstance(callback, str) and callback == "llamaguard_moderations":
from enterprise.enterprise_hooks.llama_guard import (
_ENTERPRISE_LlamaGuard,
)
if premium_user is not True:
raise Exception(
"Trying to use Llama Guard"
+ CommonProxyErrors.not_premium_user.value
)
llama_guard_object = _ENTERPRISE_LlamaGuard()
imported_list.append(llama_guard_object)
elif isinstance(callback, str) and callback == "hide_secrets":
from enterprise.enterprise_hooks.secret_detection import (
_ENTERPRISE_SecretDetection,
)
if premium_user is not True:
raise Exception(
"Trying to use secret hiding"
+ CommonProxyErrors.not_premium_user.value
)
_secret_detection_object = _ENTERPRISE_SecretDetection()
imported_list.append(_secret_detection_object)
elif isinstance(callback, str) and callback == "openai_moderations":
from enterprise.enterprise_hooks.openai_moderation import (
_ENTERPRISE_OpenAI_Moderation,
)
if premium_user is not True:
raise Exception(
"Trying to use OpenAI Moderations Check"
+ CommonProxyErrors.not_premium_user.value
)
openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
imported_list.append(openai_moderations_object)
elif isinstance(callback, str) and callback == "lakera_prompt_injection":
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
lakeraAI_Moderation,
)
init_params = {}
if "lakera_prompt_injection" in callback_specific_params:
init_params = callback_specific_params["lakera_prompt_injection"]
lakera_moderations_object = lakeraAI_Moderation(**init_params)
imported_list.append(lakera_moderations_object)
elif isinstance(callback, str) and callback == "aporia_prompt_injection":
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
AporiaGuardrail,
)
aporia_guardrail_object = AporiaGuardrail()
imported_list.append(aporia_guardrail_object)
elif isinstance(callback, str) and callback == "google_text_moderation":
from enterprise.enterprise_hooks.google_text_moderation import (
_ENTERPRISE_GoogleTextModeration,
)
if premium_user is not True:
raise Exception(
"Trying to use Google Text Moderation"
+ CommonProxyErrors.not_premium_user.value
)
google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
imported_list.append(google_text_moderation_obj)
elif isinstance(callback, str) and callback == "llmguard_moderations":
from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard
if premium_user is not True:
raise Exception(
"Trying to use Llm Guard"
+ CommonProxyErrors.not_premium_user.value
)
llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
imported_list.append(llm_guard_moderation_obj)
elif isinstance(callback, str) and callback == "blocked_user_check":
from enterprise.enterprise_hooks.blocked_user_list import (
_ENTERPRISE_BlockedUserList,
)
if premium_user is not True:
raise Exception(
"Trying to use ENTERPRISE BlockedUser"
+ CommonProxyErrors.not_premium_user.value
)
blocked_user_list = _ENTERPRISE_BlockedUserList(
prisma_client=prisma_client
)
imported_list.append(blocked_user_list)
elif isinstance(callback, str) and callback == "banned_keywords":
from enterprise.enterprise_hooks.banned_keywords import (
_ENTERPRISE_BannedKeywords,
)
if premium_user is not True:
raise Exception(
"Trying to use ENTERPRISE BannedKeyword"
+ CommonProxyErrors.not_premium_user.value
)
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
imported_list.append(banned_keywords_obj)
elif isinstance(callback, str) and callback == "detect_prompt_injection":
from litellm.proxy.hooks.prompt_injection_detection import (
_OPTIONAL_PromptInjectionDetection,
)
prompt_injection_params = None
if "prompt_injection_params" in litellm_settings:
prompt_injection_params_in_config = litellm_settings[
"prompt_injection_params"
]
prompt_injection_params = LiteLLMPromptInjectionParams(
**prompt_injection_params_in_config
)
prompt_injection_detection_obj = _OPTIONAL_PromptInjectionDetection(
prompt_injection_params=prompt_injection_params,
)
imported_list.append(prompt_injection_detection_obj)
elif isinstance(callback, str) and callback == "batch_redis_requests":
from litellm.proxy.hooks.batch_redis_get import (
_PROXY_BatchRedisRequests,
)
batch_redis_obj = _PROXY_BatchRedisRequests()
imported_list.append(batch_redis_obj)
elif isinstance(callback, str) and callback == "azure_content_safety":
from litellm.proxy.hooks.azure_content_safety import (
_PROXY_AzureContentSafety,
)
azure_content_safety_params = litellm_settings[
"azure_content_safety_params"
]
for k, v in azure_content_safety_params.items():
if (
v is not None
and isinstance(v, str)
and v.startswith("os.environ/")
):
azure_content_safety_params[k] = get_secret(v)
azure_content_safety_obj = _PROXY_AzureContentSafety(
**azure_content_safety_params,
)
imported_list.append(azure_content_safety_obj)
else:
verbose_proxy_logger.debug(
f"{blue_color_code} attempting to import custom calback={callback} {reset_color_code}"
)
imported_list.append(
get_instance_fn(
value=callback,
config_file_path=config_file_path,
)
)
if isinstance(litellm.callbacks, list):
litellm.callbacks.extend(imported_list)
else:
litellm.callbacks = imported_list # type: ignore
if "prometheus" in value:
from litellm.integrations.prometheus import PrometheusLogger
PrometheusLogger._mount_metrics_endpoint(premium_user)
else:
litellm.callbacks = [
get_instance_fn(
value=value,
config_file_path=config_file_path,
)
]
verbose_proxy_logger.debug(
f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
)
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
_litellm_params = kwargs.get("litellm_params", None) or {}
_metadata = _litellm_params.get("metadata", None) or {}
_model_group = _metadata.get("model_group", None)
if _model_group is not None:
return _model_group
return None
def get_model_group_from_request_data(data: dict) -> Optional[str]:
_metadata = data.get("metadata", None) or {}
_model_group = _metadata.get("model_group", None)
if _model_group is not None:
return _model_group
return None
def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
"""
    Helper to build the x-litellm-key-remaining-tokens-{model_group} and
    x-litellm-key-remaining-requests-{model_group} response headers.
    Returns {} when no api_key + model rpm/tpm limit is set.
"""
headers = {}
_metadata = data.get("metadata", None) or {}
model_group = get_model_group_from_request_data(data)
# Remaining Requests
remaining_requests_variable_name = f"litellm-key-remaining-requests-{model_group}"
remaining_requests = _metadata.get(remaining_requests_variable_name, None)
if remaining_requests:
headers[f"x-litellm-key-remaining-requests-{model_group}"] = remaining_requests
# Remaining Tokens
remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}"
remaining_tokens = _metadata.get(remaining_tokens_variable_name, None)
if remaining_tokens:
headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens
return headers
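
For illustration (the metadata values below are made up), here is how the per-model-group metadata keys map to response headers:

# Illustrative input/output for get_remaining_tokens_and_requests_from_request_data
data = {
    "metadata": {
        "model_group": "gpt-4",
        "litellm-key-remaining-requests-gpt-4": "42",
        "litellm-key-remaining-tokens-gpt-4": "1000",
    }
}
get_remaining_tokens_and_requests_from_request_data(data)
# -> {"x-litellm-key-remaining-requests-gpt-4": "42",
#     "x-litellm-key-remaining-tokens-gpt-4": "1000"}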
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
_metadata = request_data.get("metadata", None) or {}
headers = {}
if "applied_guardrails" in _metadata:
headers["x-litellm-applied-guardrails"] = ",".join(
_metadata["applied_guardrails"]
)
if "semantic-similarity" in _metadata:
headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"])
return headers
def add_guardrail_to_applied_guardrails_header(
request_data: Dict, guardrail_name: Optional[str]
):
if guardrail_name is None:
return
_metadata = request_data.get("metadata", None) or {}
if "applied_guardrails" in _metadata:
_metadata["applied_guardrails"].append(guardrail_name)
else:
_metadata["applied_guardrails"] = [guardrail_name]


@@ -0,0 +1,242 @@
# Start tracing memory allocations
import json
import os
import tracemalloc
from fastapi import APIRouter
from litellm import get_secret_str
from litellm._logging import verbose_proxy_logger
router = APIRouter()
if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
try:
import objgraph # type: ignore
print("growth of objects") # noqa
objgraph.show_growth()
print("\n\nMost common types") # noqa
objgraph.show_most_common_types()
roots = objgraph.get_leaking_objects()
print("\n\nLeaking objects") # noqa
objgraph.show_most_common_types(objects=roots)
except ImportError:
raise ImportError(
"objgraph not found. Please install objgraph to use this feature."
)
tracemalloc.start(10)
@router.get("/memory-usage", include_in_schema=False)
async def memory_usage():
# Take a snapshot of the current memory usage
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics("lineno")
verbose_proxy_logger.debug("TOP STATS: %s", top_stats)
# Get the top 50 memory usage lines
top_50 = top_stats[:50]
result = []
for stat in top_50:
result.append(f"{stat.traceback.format(limit=10)}: {stat.size / 1024} KiB")
return {"top_50_memory_usage": result}
@router.get("/memory-usage-in-mem-cache", include_in_schema=False)
async def memory_usage_in_mem_cache():
# returns the size of all in-memory caches on the proxy server
"""
1. user_api_key_cache
2. router_cache
3. proxy_logging_cache
4. internal_usage_cache
"""
from litellm.proxy.proxy_server import (
llm_router,
proxy_logging_obj,
user_api_key_cache,
)
if llm_router is None:
num_items_in_llm_router_cache = 0
else:
num_items_in_llm_router_cache = len(
llm_router.cache.in_memory_cache.cache_dict
) + len(llm_router.cache.in_memory_cache.ttl_dict)
num_items_in_user_api_key_cache = len(
user_api_key_cache.in_memory_cache.cache_dict
) + len(user_api_key_cache.in_memory_cache.ttl_dict)
num_items_in_proxy_logging_obj_cache = len(
proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict)
return {
"num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
"num_items_in_llm_router_cache": num_items_in_llm_router_cache,
"num_items_in_proxy_logging_obj_cache": num_items_in_proxy_logging_obj_cache,
}
@router.get("/memory-usage-in-mem-cache-items", include_in_schema=False)
async def memory_usage_in_mem_cache_items():
# returns the size of all in-memory caches on the proxy server
"""
1. user_api_key_cache
2. router_cache
3. proxy_logging_cache
4. internal_usage_cache
"""
from litellm.proxy.proxy_server import (
llm_router,
proxy_logging_obj,
user_api_key_cache,
)
if llm_router is None:
llm_router_in_memory_cache_dict = {}
llm_router_in_memory_ttl_dict = {}
else:
llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict
llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict
return {
"user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
"user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
"llm_router_cache": llm_router_in_memory_cache_dict,
"llm_router_ttl": llm_router_in_memory_ttl_dict,
"proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
"proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict,
}
@router.get("/otel-spans", include_in_schema=False)
async def get_otel_spans():
from litellm.proxy.proxy_server import open_telemetry_logger
if open_telemetry_logger is None:
return {
"otel_spans": [],
"spans_grouped_by_parent": {},
"most_recent_parent": None,
}
otel_exporter = open_telemetry_logger.OTEL_EXPORTER
if hasattr(otel_exporter, "get_finished_spans"):
recorded_spans = otel_exporter.get_finished_spans() # type: ignore
else:
recorded_spans = []
print("Spans: ", recorded_spans) # noqa
most_recent_parent = None
    most_recent_start_time = 1000000  # sentinel; any real span start_time (in ns) exceeds this
spans_grouped_by_parent = {}
for span in recorded_spans:
if span.parent is not None:
parent_trace_id = span.parent.trace_id
if parent_trace_id not in spans_grouped_by_parent:
spans_grouped_by_parent[parent_trace_id] = []
spans_grouped_by_parent[parent_trace_id].append(span.name)
# check time of span
if span.start_time > most_recent_start_time:
most_recent_parent = parent_trace_id
most_recent_start_time = span.start_time
# these are otel spans - get the span name
span_names = [span.name for span in recorded_spans]
return {
"otel_spans": span_names,
"spans_grouped_by_parent": spans_grouped_by_parent,
"most_recent_parent": most_recent_parent,
}
# Helper functions for debugging
def init_verbose_loggers():
try:
worker_config = get_secret_str("WORKER_CONFIG")
        if worker_config is None:
            return
        if os.path.isfile(worker_config):
            return
        # not a file path - assume it's a JSON string
        _settings = json.loads(worker_config)
if not isinstance(_settings, dict):
return
debug = _settings.get("debug", None)
detailed_debug = _settings.get("detailed_debug", None)
        if debug is True:  # this needs to be first, so users can see Router init debug logs
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
verbose_router_logger.setLevel(
level=logging.INFO
) # set router logs to info
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
if detailed_debug is True:
import logging
from litellm._logging import (
verbose_logger,
verbose_proxy_logger,
verbose_router_logger,
)
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
verbose_router_logger.setLevel(
level=logging.DEBUG
) # set router logs to debug
verbose_proxy_logger.setLevel(
level=logging.DEBUG
) # set proxy logs to debug
elif debug is False and detailed_debug is False:
            # users can control proxy debugging using the 'LITELLM_LOG' env variable
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
if litellm_log_setting is not None:
if litellm_log_setting.upper() == "INFO":
import logging
from litellm._logging import (
verbose_proxy_logger,
verbose_router_logger,
)
# this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
verbose_router_logger.setLevel(
level=logging.INFO
) # set router logs to info
verbose_proxy_logger.setLevel(
level=logging.INFO
) # set proxy logs to info
elif litellm_log_setting.upper() == "DEBUG":
import logging
from litellm._logging import (
verbose_proxy_logger,
verbose_router_logger,
)
verbose_router_logger.setLevel(
level=logging.DEBUG
            ) # set router logs to debug
verbose_proxy_logger.setLevel(
level=logging.DEBUG
) # set proxy logs to debug
except Exception as e:
import logging
logging.warning(f"Failed to init verbose loggers: {str(e)}")
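
For reference, a WORKER_CONFIG value that init_verbose_loggers() understands looks like the sketch below. This is illustrative only; the value is read via get_secret_str, which typically resolves from the environment.

import os

os.environ["WORKER_CONFIG"] = '{"debug": true, "detailed_debug": false}'
init_verbose_loggers()  # sets package, router, and proxy loggers to INFO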


@@ -0,0 +1,99 @@
import base64
import os
from typing import Optional
from litellm._logging import verbose_proxy_logger
def _get_salt_key():
from litellm.proxy.proxy_server import master_key
salt_key = os.getenv("LITELLM_SALT_KEY", None)
if salt_key is None:
verbose_proxy_logger.debug(
"LITELLM_SALT_KEY is None using master_key to encrypt/decrypt secrets stored in DB"
)
salt_key = master_key
return salt_key
def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):
signing_key = new_encryption_key or _get_salt_key()
try:
if isinstance(value, str):
encrypted_value = encrypt_value(value=value, signing_key=signing_key) # type: ignore
encrypted_value = base64.b64encode(encrypted_value).decode("utf-8")
return encrypted_value
verbose_proxy_logger.debug(
f"Invalid value type passed to encrypt_value: {type(value)} for Value: {value}\n Value must be a string"
)
# if it's not a string - do not encrypt it and return the value
return value
except Exception as e:
raise e
def decrypt_value_helper(value: str):
signing_key = _get_salt_key()
try:
if isinstance(value, str):
decoded_b64 = base64.b64decode(value)
value = decrypt_value(value=decoded_b64, signing_key=signing_key) # type: ignore
return value
# if it's not str - do not decrypt it, return the value
return value
except Exception as e:
verbose_proxy_logger.error(
f"Error decrypting value, Did your master_key/salt key change recently? \nError: {str(e)}\nSet permanent salt key - https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
)
        # non-blocking: one bad value should not block decrypting other values
        return None
def encrypt_value(value: str, signing_key: str):
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(signing_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# encode message #
value_bytes = value.encode("utf-8")
encrypted = box.encrypt(value_bytes)
return encrypted
def decrypt_value(value: bytes, signing_key: str) -> str:
import hashlib
import nacl.secret
import nacl.utils
# get 32 byte master key #
hash_object = hashlib.sha256(signing_key.encode())
hash_bytes = hash_object.digest()
# initialize secret box #
box = nacl.secret.SecretBox(hash_bytes)
# Convert the bytes object to a string
plaintext = box.decrypt(value)
plaintext = plaintext.decode("utf-8") # type: ignore
return plaintext # type: ignore
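
A quick round-trip sketch (requires PyNaCl; the key and secret below are made up). The same signing key must be used in both directions, which is why LITELLM_SALT_KEY must never change once values are stored.

import base64

ciphertext = encrypt_value(value="my-azure-api-key", signing_key="sk-1234")
stored = base64.b64encode(ciphertext).decode("utf-8")  # what ends up in the DB
assert decrypt_value(value=base64.b64decode(stored), signing_key="sk-1234") == "my-azure-api-key"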


@@ -0,0 +1,284 @@
# JWT display template for SSO debug callback
jwt_display_template = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>LiteLLM SSO Debug - JWT Information</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background-color: #f8fafc;
margin: 0;
padding: 20px;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
color: #333;
}
.container {
background-color: #fff;
padding: 40px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 800px;
max-width: 100%;
}
.logo-container {
text-align: center;
margin-bottom: 30px;
}
.logo {
font-size: 24px;
font-weight: 600;
color: #1e293b;
}
h2 {
margin: 0 0 10px;
color: #1e293b;
font-size: 28px;
font-weight: 600;
text-align: center;
}
.subtitle {
color: #64748b;
margin: 0 0 20px;
font-size: 16px;
text-align: center;
}
.info-box {
background-color: #f1f5f9;
border-radius: 6px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #2563eb;
}
.success-box {
background-color: #f0fdf4;
border-radius: 6px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #16a34a;
}
.info-header {
display: flex;
align-items: center;
margin-bottom: 12px;
color: #1e40af;
font-weight: 600;
font-size: 16px;
}
.success-header {
display: flex;
align-items: center;
margin-bottom: 12px;
color: #166534;
font-weight: 600;
font-size: 16px;
}
.info-header svg, .success-header svg {
margin-right: 8px;
}
.data-container {
margin-top: 20px;
}
.data-row {
display: flex;
border-bottom: 1px solid #e2e8f0;
padding: 12px 0;
}
.data-row:last-child {
border-bottom: none;
}
.data-label {
font-weight: 500;
color: #334155;
width: 180px;
flex-shrink: 0;
}
.data-value {
color: #475569;
word-break: break-all;
}
.jwt-container {
background-color: #f8fafc;
border-radius: 6px;
padding: 15px;
margin-top: 20px;
overflow-x: auto;
border: 1px solid #e2e8f0;
}
.jwt-text {
font-family: monospace;
white-space: pre-wrap;
word-break: break-all;
margin: 0;
color: #334155;
}
.back-button {
display: inline-block;
background-color: #6466E9;
color: #fff;
text-decoration: none;
padding: 10px 16px;
border-radius: 6px;
font-weight: 500;
margin-top: 20px;
text-align: center;
}
.back-button:hover {
background-color: #4138C2;
text-decoration: none;
}
.buttons {
display: flex;
gap: 10px;
margin-top: 20px;
}
.copy-button {
background-color: #e2e8f0;
color: #334155;
border: none;
padding: 8px 12px;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
display: flex;
align-items: center;
}
.copy-button:hover {
background-color: #cbd5e1;
}
.copy-button svg {
margin-right: 6px;
}
</style>
</head>
<body>
<div class="container">
<div class="logo-container">
<div class="logo">
🚅 LiteLLM
</div>
</div>
<h2>SSO Debug Information</h2>
<p class="subtitle">Results from the SSO authentication process.</p>
<div class="success-box">
<div class="success-header">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"></path>
<polyline points="22 4 12 14.01 9 11.01"></polyline>
</svg>
Authentication Successful
</div>
<p>The SSO authentication completed successfully. Below is the information returned by the provider.</p>
</div>
<div class="data-container" id="userData">
<!-- Data will be inserted here by JavaScript -->
</div>
<div class="info-box">
<div class="info-header">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10"></circle>
<line x1="12" y1="16" x2="12" y2="12"></line>
<line x1="12" y1="8" x2="12.01" y2="8"></line>
</svg>
JSON Representation
</div>
<div class="jwt-container">
<pre class="jwt-text" id="jsonData">Loading...</pre>
</div>
<div class="buttons">
<button class="copy-button" onclick="copyToClipboard('jsonData')">
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
</svg>
Copy to Clipboard
</button>
</div>
</div>
<a href="/sso/debug/login" class="back-button">
Try Another SSO Login
</a>
</div>
<script>
// This will be populated with the actual data from the server
const userData = SSO_DATA;
function renderUserData() {
const container = document.getElementById('userData');
const jsonDisplay = document.getElementById('jsonData');
// Format JSON with indentation for display
jsonDisplay.textContent = JSON.stringify(userData, null, 2);
// Clear container
container.innerHTML = '';
// Add each key-value pair to the UI
for (const [key, value] of Object.entries(userData)) {
if (typeof value !== 'object' || value === null) {
const row = document.createElement('div');
row.className = 'data-row';
const label = document.createElement('div');
label.className = 'data-label';
label.textContent = key;
const dataValue = document.createElement('div');
dataValue.className = 'data-value';
dataValue.textContent = value !== null ? value : 'null';
row.appendChild(label);
row.appendChild(dataValue);
container.appendChild(row);
}
}
}
function copyToClipboard(elementId) {
const text = document.getElementById(elementId).textContent;
navigator.clipboard.writeText(text).then(() => {
alert('Copied to clipboard!');
}).catch(err => {
console.error('Could not copy text: ', err);
});
}
// Render the data when the page loads
document.addEventListener('DOMContentLoaded', renderUserData);
</script>
</body>
</html>
"""


@@ -0,0 +1,217 @@
import os
url_to_redirect_to = os.getenv("PROXY_BASE_URL", "")
url_to_redirect_to += "/login"
html_form = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>LiteLLM Login</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background-color: #f8fafc;
margin: 0;
padding: 20px;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
color: #333;
}}
form {{
background-color: #fff;
padding: 40px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 450px;
max-width: 100%;
}}
.logo-container {{
text-align: center;
margin-bottom: 30px;
}}
.logo {{
font-size: 24px;
font-weight: 600;
color: #1e293b;
}}
h2 {{
margin: 0 0 10px;
color: #1e293b;
font-size: 28px;
font-weight: 600;
text-align: center;
}}
.subtitle {{
color: #64748b;
margin: 0 0 20px;
font-size: 16px;
text-align: center;
}}
.info-box {{
background-color: #f1f5f9;
border-radius: 6px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #2563eb;
}}
.info-header {{
display: flex;
align-items: center;
margin-bottom: 12px;
color: #1e40af;
font-weight: 600;
font-size: 16px;
}}
.info-header svg {{
margin-right: 8px;
}}
.info-box p {{
color: #475569;
margin: 8px 0;
line-height: 1.5;
font-size: 14px;
}}
label {{
display: block;
margin-bottom: 8px;
font-weight: 500;
color: #334155;
font-size: 14px;
}}
.required {{
color: #dc2626;
margin-left: 2px;
}}
input[type="text"],
input[type="password"] {{
width: 100%;
padding: 10px 14px;
margin-bottom: 20px;
box-sizing: border-box;
border: 1px solid #e2e8f0;
border-radius: 6px;
font-size: 15px;
color: #1e293b;
background-color: #fff;
transition: border-color 0.2s, box-shadow 0.2s;
}}
input[type="text"]:focus,
input[type="password"]:focus {{
outline: none;
border-color: #3b82f6;
box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.2);
}}
.toggle-password {{
display: flex;
align-items: center;
margin-top: -15px;
margin-bottom: 20px;
}}
.toggle-password input {{
margin-right: 6px;
}}
input[type="submit"] {{
background-color: #6466E9;
color: #fff;
cursor: pointer;
font-weight: 500;
border: none;
padding: 10px 16px;
transition: background-color 0.2s;
border-radius: 6px;
margin-top: 10px;
font-size: 14px;
width: 100%;
}}
input[type="submit"]:hover {{
background-color: #4138C2;
}}
a {{
color: #3b82f6;
text-decoration: none;
}}
a:hover {{
text-decoration: underline;
}}
code {{
background-color: #f1f5f9;
padding: 2px 4px;
border-radius: 4px;
font-family: monospace;
font-size: 13px;
color: #334155;
}}
.help-text {{
color: #64748b;
font-size: 14px;
margin-top: -12px;
margin-bottom: 20px;
}}
</style>
</head>
<body>
<form action="{url_to_redirect_to}" method="post">
<div class="logo-container">
<div class="logo">
🚅 LiteLLM
</div>
</div>
<h2>Login</h2>
<p class="subtitle">Access your LiteLLM Admin UI.</p>
<div class="info-box">
<div class="info-header">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10"></circle>
<line x1="12" y1="16" x2="12" y2="12"></line>
<line x1="12" y1="8" x2="12.01" y2="8"></line>
</svg>
Default Credentials
</div>
<p>By default, Username is <code>admin</code> and Password is your set LiteLLM Proxy <code>MASTER_KEY</code>.</p>
<p>Need to set UI credentials or SSO? <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">Check the documentation</a>.</p>
</div>
<label for="username">Username<span class="required">*</span></label>
<input type="text" id="username" name="username" required placeholder="Enter your username" autocomplete="username">
<label for="password">Password<span class="required">*</span></label>
<input type="password" id="password" name="password" required placeholder="Enter your password" autocomplete="current-password">
<div class="toggle-password">
<input type="checkbox" id="show-password" onclick="togglePasswordVisibility()">
<label for="show-password">Show password</label>
</div>
<input type="submit" value="Login">
</form>
<script>
function togglePasswordVisibility() {{
var passwordField = document.getElementById("password");
passwordField.type = passwordField.type === "password" ? "text" : "password";
}}
</script>
</body>
</html>
"""


@@ -0,0 +1,187 @@
import json
from typing import Dict, List, Optional
import orjson
from fastapi import Request, UploadFile, status
from litellm._logging import verbose_proxy_logger
from litellm.types.router import Deployment
async def _read_request_body(request: Optional[Request]) -> Dict:
"""
Safely read the request body and parse it as JSON.
Parameters:
- request: The request object to read the body from
Returns:
- dict: Parsed request data as a dictionary or an empty dictionary if parsing fails
"""
try:
if request is None:
return {}
# Check if we already read and parsed the body
_cached_request_body: Optional[dict] = _safe_get_request_parsed_body(
request=request
)
if _cached_request_body is not None:
return _cached_request_body
_request_headers: dict = _safe_get_request_headers(request=request)
content_type = _request_headers.get("content-type", "")
if "form" in content_type:
parsed_body = dict(await request.form())
else:
# Read the request body
body = await request.body()
# Return empty dict if body is empty or None
if not body:
parsed_body = {}
else:
try:
parsed_body = orjson.loads(body)
except orjson.JSONDecodeError:
# Fall back to the standard json module which is more forgiving
# First decode bytes to string if needed
body_str = body.decode("utf-8") if isinstance(body, bytes) else body
# Replace invalid surrogate pairs
import re
                    # strip high surrogates that are not followed by a low surrogate
                    body_str = re.sub(
                        r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])", "", body_str
                    )
                    # strip low surrogates that are not preceded by a high surrogate
                    body_str = re.sub(
                        r"(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]", "", body_str
                    )
parsed_body = json.loads(body_str)
# Cache the parsed result
_safe_set_request_parsed_body(request=request, parsed_body=parsed_body)
return parsed_body
except (json.JSONDecodeError, orjson.JSONDecodeError):
verbose_proxy_logger.exception("Invalid JSON payload received.")
return {}
except Exception as e:
# Catch unexpected errors to avoid crashes
verbose_proxy_logger.exception(
"Unexpected error reading request body - {}".format(e)
)
return {}
def _safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]:
if request is None:
return None
if (
hasattr(request, "scope")
and "parsed_body" in request.scope
and isinstance(request.scope["parsed_body"], tuple)
):
accepted_keys, parsed_body = request.scope["parsed_body"]
return {key: parsed_body[key] for key in accepted_keys}
return None
def _safe_set_request_parsed_body(
request: Optional[Request],
parsed_body: dict,
) -> None:
try:
if request is None:
return
request.scope["parsed_body"] = (tuple(parsed_body.keys()), parsed_body)
except Exception as e:
verbose_proxy_logger.debug(
"Unexpected error setting request parsed body - {}".format(e)
)
def _safe_get_request_headers(request: Optional[Request]) -> dict:
"""
[Non-Blocking] Safely get the request headers
"""
try:
if request is None:
return {}
return dict(request.headers)
except Exception as e:
verbose_proxy_logger.debug(
"Unexpected error reading request headers - {}".format(e)
)
return {}
def check_file_size_under_limit(
request_data: dict,
file: UploadFile,
router_model_names: List[str],
) -> bool:
"""
Check if any files passed in request are under max_file_size_mb
Returns True -> when file size is under max_file_size_mb limit
Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user
"""
from litellm.proxy.proxy_server import (
CommonProxyErrors,
ProxyException,
llm_router,
premium_user,
)
file_contents_size = file.size or 0
file_content_size_in_mb = file_contents_size / (1024 * 1024)
if "metadata" not in request_data:
request_data["metadata"] = {}
request_data["metadata"]["file_size_in_mb"] = file_content_size_in_mb
max_file_size_mb = None
if llm_router is not None and request_data["model"] in router_model_names:
try:
deployment: Optional[
Deployment
] = llm_router.get_deployment_by_model_group_name(
model_group_name=request_data["model"]
)
if (
deployment
and deployment.litellm_params is not None
and deployment.litellm_params.max_file_size_mb is not None
):
max_file_size_mb = deployment.litellm_params.max_file_size_mb
except Exception as e:
verbose_proxy_logger.error(
"Got error when checking file size: %s", (str(e))
)
if max_file_size_mb is not None:
verbose_proxy_logger.debug(
"Checking file size, file content size=%s, max_file_size_mb=%s",
file_content_size_in_mb,
max_file_size_mb,
)
if not premium_user:
raise ProxyException(
message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}",
code=status.HTTP_400_BAD_REQUEST,
type="bad_request",
param="file",
)
if file_content_size_in_mb > max_file_size_mb:
raise ProxyException(
message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. Max file size: {max_file_size_mb} MB",
code=status.HTTP_400_BAD_REQUEST,
type="bad_request",
param="file",
)
return True
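
Because _read_request_body() caches the parsed payload on request.scope, repeated reads within one request are cheap. A minimal sketch with a hypothetical endpoint:

from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/echo")  # illustrative route, not part of this file
async def echo(request: Request):
    data = await _read_request_body(request)        # parses and caches on request.scope
    data_again = await _read_request_body(request)  # served from the cached parsed_body
    return {"model": data.get("model"), "cached_equal": data == data_again}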


@@ -0,0 +1,76 @@
import yaml
from litellm._logging import verbose_proxy_logger
def get_file_contents_from_s3(bucket_name, object_key):
try:
        # v0 relies on boto3 for authentication, letting boto3 handle IAM credentials etc.
import tempfile
import boto3
from botocore.credentials import Credentials
from litellm.main import bedrock_converse_chat_completion
credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
s3_client = boto3.client(
"s3",
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token, # Optional, if using temporary credentials
)
verbose_proxy_logger.debug(
f"Retrieving {object_key} from S3 bucket: {bucket_name}"
)
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
verbose_proxy_logger.debug(f"Response: {response}")
# Read the file contents
file_contents = response["Body"].read().decode("utf-8")
verbose_proxy_logger.debug("File contents retrieved from S3")
# Create a temporary file with YAML extension
with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_file:
temp_file.write(file_contents.encode("utf-8"))
temp_file_path = temp_file.name
verbose_proxy_logger.debug(f"File stored temporarily at: {temp_file_path}")
# Load the YAML file content
with open(temp_file_path, "r") as yaml_file:
config = yaml.safe_load(yaml_file)
return config
except ImportError as e:
# this is most likely if a user is not using the litellm docker container
verbose_proxy_logger.error(f"ImportError: {str(e)}")
pass
except Exception as e:
verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
return None
async def get_config_file_contents_from_gcs(bucket_name, object_key):
try:
from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger
gcs_bucket = GCSBucketLogger(
bucket_name=bucket_name,
)
file_contents = await gcs_bucket.download_gcs_object(object_key)
if file_contents is None:
raise Exception(f"File contents are None for {object_key}")
        # file_contents is a bytes object - decode it to a string
        file_contents = file_contents.decode("utf-8")
        # parse the YAML content
config = yaml.safe_load(file_contents)
return config
except Exception as e:
verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
return None
# # Example usage
# bucket_name = 'litellm-proxy'
# object_key = 'litellm_proxy_config.yaml'
# config = get_file_contents_from_s3(bucket_name, object_key)


@@ -0,0 +1,39 @@
"""
Contains utils used by OpenAI compatible endpoints
"""
from typing import Optional
from fastapi import Request
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict:
"""
Removes sensitive information from a deployment dictionary.
Args:
deployment_dict (dict): The deployment dictionary to remove sensitive information from.
Returns:
dict: The modified deployment dictionary with sensitive information removed.
"""
deployment_dict["litellm_params"].pop("api_key", None)
deployment_dict["litellm_params"].pop("vertex_credentials", None)
deployment_dict["litellm_params"].pop("aws_access_key_id", None)
deployment_dict["litellm_params"].pop("aws_secret_access_key", None)
return deployment_dict
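
For illustration (made-up deployment dict), the credential fields are popped in place:

deployment = {
    "model_name": "gpt-4",
    "litellm_params": {"model": "azure/gpt-4", "api_key": "sk-secret"},
}
remove_sensitive_info_from_deployment(deployment)
# -> {"model_name": "gpt-4", "litellm_params": {"model": "azure/gpt-4"}}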
async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]:
"""
Get the `custom_llm_provider` from the request body
Safely reads the request body
"""
request_body: dict = await _read_request_body(request=request) or {}
if "custom_llm_provider" in request_body:
return request_body["custom_llm_provider"]
return None


@@ -0,0 +1,36 @@
"""
This file is used to store the state variables of the proxy server.
Example: `spend_logs_row_count` is used to store the number of rows in the `LiteLLM_SpendLogs` table.
"""
from typing import Any, Literal
from litellm.proxy._types import ProxyStateVariables
class ProxyState:
"""
Proxy state class has get/set methods for Proxy state variables.
"""
    # Note: mypy does not recognize when we fetch ProxyStateVariables.__annotations__.keys(), so we also need to add the valid keys here
valid_keys_literal = Literal["spend_logs_row_count"]
def __init__(self) -> None:
self.proxy_state_variables: ProxyStateVariables = ProxyStateVariables(
spend_logs_row_count=0,
)
def get_proxy_state_variable(
self,
variable_name: valid_keys_literal,
) -> Any:
return self.proxy_state_variables.get(variable_name, None)
def set_proxy_state_variable(
self,
variable_name: valid_keys_literal,
value: Any,
) -> None:
self.proxy_state_variables[variable_name] = value
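
Usage is a simple get/set pair, e.g.:

proxy_state = ProxyState()
proxy_state.set_proxy_state_variable("spend_logs_row_count", 1024)
proxy_state.get_proxy_state_variable("spend_logs_row_count")  # -> 1024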


@@ -0,0 +1,365 @@
import asyncio
import json
import time
from datetime import datetime, timedelta
from typing import List, Literal, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_UserTable,
LiteLLM_VerificationToken,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.types.services import ServiceTypes
class ResetBudgetJob:
"""
Resets the budget for all the keys, users, and teams that need it
"""
def __init__(self, proxy_logging_obj: ProxyLogging, prisma_client: PrismaClient):
self.proxy_logging_obj: ProxyLogging = proxy_logging_obj
self.prisma_client: PrismaClient = prisma_client
async def reset_budget(
self,
):
"""
        Gets all the non-expired keys from the DB whose spend needs to be reset,
        resets their spend, and updates the DB.
"""
if self.prisma_client is not None:
### RESET KEY BUDGET ###
await self.reset_budget_for_litellm_keys()
### RESET USER BUDGET ###
await self.reset_budget_for_litellm_users()
## Reset Team Budget
await self.reset_budget_for_litellm_teams()
async def reset_budget_for_litellm_keys(self):
"""
Resets the budget for all the litellm keys
Catches Exceptions and logs them
"""
now = datetime.utcnow()
start_time = time.time()
keys_to_reset: Optional[List[LiteLLM_VerificationToken]] = None
try:
keys_to_reset = await self.prisma_client.get_data(
table_name="key", query_type="find_all", expires=now, reset_at=now
)
verbose_proxy_logger.debug(
"Keys to reset %s", json.dumps(keys_to_reset, indent=4, default=str)
)
updated_keys: List[LiteLLM_VerificationToken] = []
failed_keys = []
if keys_to_reset is not None and len(keys_to_reset) > 0:
for key in keys_to_reset:
try:
updated_key = await ResetBudgetJob._reset_budget_for_key(
key=key, current_time=now
)
if updated_key is not None:
updated_keys.append(updated_key)
else:
failed_keys.append(
{"key": key, "error": "Returned None without exception"}
)
except Exception as e:
failed_keys.append({"key": key, "error": str(e)})
verbose_proxy_logger.exception(
"Failed to reset budget for key: %s", key
)
verbose_proxy_logger.debug(
"Updated keys %s", json.dumps(updated_keys, indent=4, default=str)
)
if updated_keys:
await self.prisma_client.update_data(
query_type="update_many",
data_list=updated_keys,
table_name="key",
)
end_time = time.time()
if len(failed_keys) > 0: # If any keys failed to reset
raise Exception(
f"Failed to reset {len(failed_keys)} keys: {json.dumps(failed_keys, default=str)}"
)
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
call_type="reset_budget_keys",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
"keys_found": json.dumps(keys_to_reset, indent=4, default=str),
"num_keys_updated": len(updated_keys),
"keys_updated": json.dumps(updated_keys, indent=4, default=str),
"num_keys_failed": len(failed_keys),
"keys_failed": json.dumps(failed_keys, indent=4, default=str),
},
)
)
except Exception as e:
end_time = time.time()
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
error=e,
call_type="reset_budget_keys",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
"keys_found": json.dumps(keys_to_reset, indent=4, default=str),
},
)
)
verbose_proxy_logger.exception("Failed to reset budget for keys: %s", e)
async def reset_budget_for_litellm_users(self):
"""
Resets the budget for all LiteLLM Internal Users if their budget has expired
"""
now = datetime.utcnow()
start_time = time.time()
users_to_reset: Optional[List[LiteLLM_UserTable]] = None
try:
users_to_reset = await self.prisma_client.get_data(
table_name="user", query_type="find_all", reset_at=now
)
updated_users: List[LiteLLM_UserTable] = []
failed_users = []
if users_to_reset is not None and len(users_to_reset) > 0:
for user in users_to_reset:
try:
updated_user = await ResetBudgetJob._reset_budget_for_user(
user=user, current_time=now
)
if updated_user is not None:
updated_users.append(updated_user)
else:
failed_users.append(
{
"user": user,
"error": "Returned None without exception",
}
)
except Exception as e:
failed_users.append({"user": user, "error": str(e)})
verbose_proxy_logger.exception(
"Failed to reset budget for user: %s", user
)
verbose_proxy_logger.debug(
"Updated users %s", json.dumps(updated_users, indent=4, default=str)
)
if updated_users:
await self.prisma_client.update_data(
query_type="update_many",
data_list=updated_users,
table_name="user",
)
end_time = time.time()
if len(failed_users) > 0: # If any users failed to reset
raise Exception(
f"Failed to reset {len(failed_users)} users: {json.dumps(failed_users, default=str)}"
)
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
call_type="reset_budget_users",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_users_found": len(users_to_reset) if users_to_reset else 0,
"users_found": json.dumps(
users_to_reset, indent=4, default=str
),
"num_users_updated": len(updated_users),
"users_updated": json.dumps(
updated_users, indent=4, default=str
),
"num_users_failed": len(failed_users),
"users_failed": json.dumps(failed_users, indent=4, default=str),
},
)
)
except Exception as e:
end_time = time.time()
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
error=e,
call_type="reset_budget_users",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_users_found": len(users_to_reset) if users_to_reset else 0,
"users_found": json.dumps(
users_to_reset, indent=4, default=str
),
},
)
)
verbose_proxy_logger.exception("Failed to reset budget for users: %s", e)
async def reset_budget_for_litellm_teams(self):
"""
Resets the budget for all LiteLLM Internal Teams if their budget has expired
"""
now = datetime.utcnow()
start_time = time.time()
teams_to_reset: Optional[List[LiteLLM_TeamTable]] = None
try:
teams_to_reset = await self.prisma_client.get_data(
table_name="team", query_type="find_all", reset_at=now
)
updated_teams: List[LiteLLM_TeamTable] = []
failed_teams = []
if teams_to_reset is not None and len(teams_to_reset) > 0:
for team in teams_to_reset:
try:
updated_team = await ResetBudgetJob._reset_budget_for_team(
team=team, current_time=now
)
if updated_team is not None:
updated_teams.append(updated_team)
else:
failed_teams.append(
{
"team": team,
"error": "Returned None without exception",
}
)
except Exception as e:
failed_teams.append({"team": team, "error": str(e)})
verbose_proxy_logger.exception(
"Failed to reset budget for team: %s", team
)
verbose_proxy_logger.debug(
"Updated teams %s", json.dumps(updated_teams, indent=4, default=str)
)
if updated_teams:
await self.prisma_client.update_data(
query_type="update_many",
data_list=updated_teams,
table_name="team",
)
end_time = time.time()
if len(failed_teams) > 0: # If any teams failed to reset
raise Exception(
f"Failed to reset {len(failed_teams)} teams: {json.dumps(failed_teams, default=str)}"
)
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
call_type="reset_budget_teams",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
"teams_found": json.dumps(
teams_to_reset, indent=4, default=str
),
"num_teams_updated": len(updated_teams),
"teams_updated": json.dumps(
updated_teams, indent=4, default=str
),
"num_teams_failed": len(failed_teams),
"teams_failed": json.dumps(failed_teams, indent=4, default=str),
},
)
)
except Exception as e:
end_time = time.time()
asyncio.create_task(
self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
service=ServiceTypes.RESET_BUDGET_JOB,
duration=end_time - start_time,
error=e,
call_type="reset_budget_teams",
start_time=start_time,
end_time=end_time,
event_metadata={
"num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
"teams_found": json.dumps(
teams_to_reset, indent=4, default=str
),
},
)
)
verbose_proxy_logger.exception("Failed to reset budget for teams: %s", e)
@staticmethod
async def _reset_budget_common(
item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
current_time: datetime,
item_type: Literal["key", "team", "user"],
):
"""
        In-place, sets spend=0.0 and budget_reset_at to current_time + budget_duration.
        Common logic for resetting the budget of a team, user, or key.
"""
try:
item.spend = 0.0
if hasattr(item, "budget_duration") and item.budget_duration is not None:
duration_s = duration_in_seconds(duration=item.budget_duration)
item.budget_reset_at = current_time + timedelta(seconds=duration_s)
return item
except Exception as e:
verbose_proxy_logger.exception(
"Error resetting budget for %s: %s. Item: %s", item_type, e, item
)
raise e
@staticmethod
async def _reset_budget_for_team(
team: LiteLLM_TeamTable, current_time: datetime
) -> Optional[LiteLLM_TeamTable]:
await ResetBudgetJob._reset_budget_common(
item=team, current_time=current_time, item_type="team"
)
return team
@staticmethod
async def _reset_budget_for_user(
user: LiteLLM_UserTable, current_time: datetime
) -> Optional[LiteLLM_UserTable]:
await ResetBudgetJob._reset_budget_common(
item=user, current_time=current_time, item_type="user"
)
return user
@staticmethod
async def _reset_budget_for_key(
key: LiteLLM_VerificationToken, current_time: datetime
) -> Optional[LiteLLM_VerificationToken]:
await ResetBudgetJob._reset_budget_common(
item=key, current_time=current_time, item_type="key"
)
return key
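
A minimal sketch of the reset semantics using a duck-typed stand-in. The real callers pass pydantic table models; SimpleNamespace is used here only to show the in-place mutation.

import asyncio
from datetime import datetime
from types import SimpleNamespace

key = SimpleNamespace(spend=12.5, budget_duration="30d", budget_reset_at=None)
asyncio.run(
    ResetBudgetJob._reset_budget_common(item=key, current_time=datetime.utcnow(), item_type="key")
)
# key.spend == 0.0; key.budget_reset_at == current_time + 30 days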


@@ -0,0 +1,48 @@
from typing import Any, Dict
from pydantic import BaseModel, Field
from litellm.exceptions import LITELLM_EXCEPTION_TYPES
class ErrorResponse(BaseModel):
detail: Dict[str, Any] = Field(
...,
example={ # type: ignore
"error": {
"message": "Error message",
"type": "error_type",
"param": "error_param",
"code": "error_code",
}
},
)
# Define a function to get the status code
def get_status_code(exception):
if hasattr(exception, "status_code"):
return exception.status_code
# Default status codes for exceptions without a status_code attribute
if exception.__name__ == "Timeout":
return 408 # Request Timeout
if exception.__name__ == "APIConnectionError":
return 503 # Service Unavailable
return 500 # Internal Server Error as default
# Create error responses
ERROR_RESPONSES = {
get_status_code(exception): {
"model": ErrorResponse,
"description": exception.__doc__ or exception.__name__,
}
for exception in LITELLM_EXCEPTION_TYPES
}
# Ensure we have a 500 error response
if 500 not in ERROR_RESPONSES:
ERROR_RESPONSES[500] = {
"model": ErrorResponse,
"description": "Internal Server Error",
}
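
A hypothetical wiring example: passing ERROR_RESPONSES to a FastAPI route's responses= parameter documents each LiteLLM exception's status code in the OpenAPI schema.

from fastapi import APIRouter

router = APIRouter()

@router.post("/chat/completions", responses=ERROR_RESPONSES)  # illustrative route
async def chat_completions():
    ...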