structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
from typing import Any, Literal, List
class CustomDB:
"""
Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
"""
def __init__(self) -> None:
pass
def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
"""
Check if key is valid
"""
pass
def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
"""
For new key / user logic
"""
pass
def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
"""
For cost tracking logic
"""
pass
def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
"""
For /key/delete endpoints
"""
def connect(
self,
):
"""
For connecting to db and creating / updating any tables
"""
pass
def disconnect(
self,
):
"""
For closing connection on server shutdown
"""
pass
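As a quick illustration of the intended contract, here is a minimal sketch of a dict-backed subclass; `InMemoryDB` and its storage layout are assumptions made for this example, not part of the commit.

```
class InMemoryDB(CustomDB):
    """Hypothetical dict-backed implementation, for illustration only."""

    def __init__(self) -> None:
        # one dict per logical table
        self.tables: dict = {"user": {}, "key": {}, "config": {}}

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        return self.tables[table_name].get(key)

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        # assumes each value carries its own "id" field
        self.tables[table_name][value["id"]] = value

    def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
        for k in keys:
            self.tables[table_name].pop(k, None)
```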

View File

@@ -0,0 +1,104 @@
"""Module for checking differences between Prisma schema and database."""
import os
import subprocess
from typing import List, Optional, Tuple
from litellm._logging import verbose_logger
def extract_sql_commands(diff_output: str) -> List[str]:
"""
Extract SQL commands from the Prisma migrate diff output.
Args:
diff_output (str): The full output from prisma migrate diff.
Returns:
List[str]: A list of SQL commands extracted from the diff output.
"""
# Split the output into lines and remove empty lines
lines = [line.strip() for line in diff_output.split("\n") if line.strip()]
sql_commands = []
current_command = ""
in_sql_block = False
for line in lines:
if line.startswith("-- "): # Comment line, likely a table operation description
if in_sql_block and current_command:
sql_commands.append(current_command.strip())
current_command = ""
in_sql_block = True
elif in_sql_block:
if line.endswith(";"):
current_command += line
sql_commands.append(current_command.strip())
current_command = ""
in_sql_block = False
else:
current_command += line + " "
# Add any remaining command
if current_command:
sql_commands.append(current_command.strip())
return sql_commands
def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]:
"""Checks for differences between current database and Prisma schema.
Returns:
A tuple containing:
- A boolean indicating if differences were found (True) or not (False).
- A string with the diff output or error message.
Raises:
subprocess.CalledProcessError: If the Prisma command fails.
Exception: For any other errors during execution.
"""
verbose_logger.debug("Checking for Prisma schema diff...") # noqa: T201
try:
result = subprocess.run(
[
"prisma",
"migrate",
"diff",
"--from-url",
db_url,
"--to-schema-datamodel",
"./schema.prisma",
"--script",
],
capture_output=True,
text=True,
check=True,
)
# return True, "Migration diff generated successfully."
sql_commands = extract_sql_commands(result.stdout)
if sql_commands:
print("Changes to DB Schema detected") # noqa: T201
print("Required SQL commands:") # noqa: T201
for command in sql_commands:
print(command) # noqa: T201
return True, sql_commands
else:
return False, []
except subprocess.CalledProcessError as e:
error_message = f"Failed to generate migration diff. Error: {e.stderr}"
print(error_message) # noqa: T201
return False, []
def check_prisma_schema_diff(db_url: Optional[str] = None) -> None:
"""Main function to run the Prisma schema diff check."""
if db_url is None:
db_url = os.getenv("DATABASE_URL")
if db_url is None:
raise Exception("DATABASE_URL not set")
has_diff, message = check_prisma_schema_diff_helper(db_url)
if has_diff:
verbose_logger.exception(
"🚨🚨🚨 prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format(
message
)
)
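For reference, a small sketch of how `extract_sql_commands` parses `prisma migrate diff --script` output; the sample diff text below is made up for the example.

```
sample_diff = """
-- CreateTable
CREATE TABLE "LiteLLM_Example" (
    "id" TEXT NOT NULL
);
-- AlterTable
ALTER TABLE "LiteLLM_Example" ADD COLUMN "spend" DOUBLE PRECISION;
"""

print(extract_sql_commands(sample_diff))
# ['CREATE TABLE "LiteLLM_Example" ( "id" TEXT NOT NULL );',
#  'ALTER TABLE "LiteLLM_Example" ADD COLUMN "spend" DOUBLE PRECISION;']
```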

View File

@@ -0,0 +1,227 @@
from typing import Any
from litellm import verbose_logger
_db = Any
async def create_missing_views(db: _db): # noqa: PLR0915
"""
--------------------------------------------------
NOTE: Copy of `litellm/db_scripts/create_views.py`.
--------------------------------------------------
Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend views exist in the user's db.
LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth
MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month
If a view doesn't exist, it will be created.
"""
try:
# Try to select one row from the view
await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
print("LiteLLM_VerificationTokenView Exists!") # noqa
except Exception:
# If an error occurs, the view does not exist, so create it
await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
t.spend AS team_spend,
t.max_budget AS team_max_budget,
t.tpm_limit AS team_tpm_limit,
t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
"""
)
print("LiteLLM_VerificationTokenView Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
print("MonthlyGlobalSpend Exists!") # noqa
except Exception:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
SELECT
DATE("startTime") AS date,
SUM("spend") AS spend
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
DATE("startTime");
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
print("Last30dKeysBySpend Exists!") # noqa
except Exception:
sql_query = """
CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
SELECT
L."api_key",
V."key_alias",
V."key_name",
SUM(L."spend") AS total_spend
FROM
"LiteLLM_SpendLogs" L
LEFT JOIN
"LiteLLM_VerificationToken" V
ON
L."api_key" = V."token"
WHERE
L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
L."api_key", V."key_alias", V."key_name"
ORDER BY
total_spend DESC;
"""
await db.execute_raw(query=sql_query)
print("Last30dKeysBySpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
print("Last30dModelsBySpend Exists!") # noqa
except Exception:
sql_query = """
CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
SELECT
"model",
SUM("spend") AS total_spend
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
AND "model" != ''
GROUP BY
"model"
ORDER BY
total_spend DESC;
"""
await db.execute_raw(query=sql_query)
print("Last30dModelsBySpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
print("MonthlyGlobalSpendPerKey Exists!") # noqa
except Exception:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
SELECT
DATE("startTime") AS date,
SUM("spend") AS spend,
api_key as api_key
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
DATE("startTime"),
api_key;
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpendPerKey Created!") # noqa
try:
await db.query_raw(
"""SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
)
print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa
except Exception:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
SELECT
DATE("startTime") AS date,
SUM("spend") AS spend,
api_key as api_key,
"user" as "user"
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
DATE("startTime"),
"user",
api_key;
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
print("DailyTagSpend Exists!") # noqa
except Exception:
sql_query = """
CREATE OR REPLACE VIEW "DailyTagSpend" AS
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
DATE(s."startTime") AS spend_date,
COUNT(*) AS log_count,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs" s
GROUP BY individual_request_tag, DATE(s."startTime");
"""
await db.execute_raw(query=sql_query)
print("DailyTagSpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
print("Last30dTopEndUsersSpend Exists!") # noqa
except Exception:
sql_query = """
CREATE VIEW "Last30dTopEndUsersSpend" AS
SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs"
WHERE end_user <> '' AND end_user <> user
AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
GROUP BY end_user
ORDER BY total_spend DESC
LIMIT 100;
"""
await db.execute_raw(query=sql_query)
print("Last30dTopEndUsersSpend Created!") # noqa
return
async def should_create_missing_views(db: _db) -> bool:
"""
Run only on first time startup.
If SpendLogs table already has values, then don't create views on startup.
"""
sql_query = """
SELECT reltuples::BIGINT
FROM pg_class
WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
"""
result = await db.query_raw(query=sql_query)
verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
if (
result
and isinstance(result, list)
and len(result) > 0
and isinstance(result[0], dict)
and "reltuples" in result[0]
and result[0]["reltuples"]
and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1)
):
verbose_logger.debug("Should create views")
return True
return False
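A hedged sketch of how these two helpers could be wired together at proxy startup; the `db` argument is assumed to be a connected Prisma client exposing the `query_raw`/`execute_raw` methods used above.

```
async def ensure_views(db) -> None:
    # Only build the views on a fresh database, per should_create_missing_views()
    if await should_create_missing_views(db=db):
        await create_missing_views(db=db)
```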

File diff suppressed because it is too large

View File

@@ -0,0 +1,49 @@
"""
Base class for in memory buffer for database transactions
"""
import asyncio
from typing import Optional
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.constants import MAX_IN_MEMORY_QUEUE_FLUSH_COUNT, MAX_SIZE_IN_MEMORY_QUEUE
service_logger_obj = (
ServiceLogging()
)  # used for tracking metrics for the in-memory buffer, Redis buffer, and pod lock manager
class BaseUpdateQueue:
"""Base class for in memory buffer for database transactions"""
def __init__(self):
self.update_queue = asyncio.Queue()
self.MAX_SIZE_IN_MEMORY_QUEUE = MAX_SIZE_IN_MEMORY_QUEUE
async def add_update(self, update):
"""Enqueue an update."""
verbose_proxy_logger.debug("Adding update to queue: %s", update)
await self.update_queue.put(update)
await self._emit_new_item_added_to_queue_event(
queue_size=self.update_queue.qsize()
)
async def flush_all_updates_from_in_memory_queue(self):
"""Get all updates from the queue."""
updates = []
while not self.update_queue.empty():
# Circuit breaker to ensure we're not stuck dequeuing updates. Protect CPU utilization
if len(updates) >= MAX_IN_MEMORY_QUEUE_FLUSH_COUNT:
verbose_proxy_logger.warning(
"Max in memory queue flush count reached, stopping flush"
)
break
updates.append(await self.update_queue.get())
return updates
async def _emit_new_item_added_to_queue_event(
self,
queue_size: Optional[int] = None,
):
"""placeholder, emit event when a new item is added to the queue"""
pass
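A minimal usage sketch of the buffer; the payload dict is illustrative, since the base class places no constraints on what an update looks like.

```
async def demo() -> None:
    queue = BaseUpdateQueue()
    await queue.add_update({"entity_id": "user_123", "response_cost": 0.002})
    updates = await queue.flush_all_updates_from_in_memory_queue()
    print(updates)  # [{'entity_id': 'user_123', 'response_cost': 0.002}]

asyncio.run(demo())
```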

View File

@@ -0,0 +1,149 @@
import asyncio
from copy import deepcopy
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import BaseDailySpendTransaction
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
BaseUpdateQueue,
service_logger_obj,
)
from litellm.types.services import ServiceTypes
class DailySpendUpdateQueue(BaseUpdateQueue):
"""
In memory buffer for daily spend updates that should be committed to the database
To add a new daily spend update transaction, use the following format:
daily_spend_update_queue.add_update({
"user1_date_api_key_model_custom_llm_provider": {
"spend": 10,
"prompt_tokens": 100,
"completion_tokens": 100,
}
})
Queue contains a list of daily spend update transactions
eg
queue = [
{
"user1_date_api_key_model_custom_llm_provider": {
"spend": 10,
"prompt_tokens": 100,
"completion_tokens": 100,
"api_requests": 100,
"successful_requests": 100,
"failed_requests": 100,
}
},
{
"user2_date_api_key_model_custom_llm_provider": {
"spend": 10,
"prompt_tokens": 100,
"completion_tokens": 100,
"api_requests": 100,
"successful_requests": 100,
"failed_requests": 100,
}
}
]
"""
def __init__(self):
super().__init__()
self.update_queue: asyncio.Queue[Dict[str, BaseDailySpendTransaction]] = (
asyncio.Queue()
)
async def add_update(self, update: Dict[str, BaseDailySpendTransaction]):
"""Enqueue an update."""
verbose_proxy_logger.debug("Adding update to queue: %s", update)
await self.update_queue.put(update)
if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
verbose_proxy_logger.warning(
"Spend update queue is full. Aggregating all entries in queue to concatenate entries."
)
await self.aggregate_queue_updates()
async def aggregate_queue_updates(self):
"""
Combine all updates in the queue into a single update.
This is used to reduce the size of the in-memory queue.
"""
updates: List[Dict[str, BaseDailySpendTransaction]] = (
await self.flush_all_updates_from_in_memory_queue()
)
aggregated_updates = self.get_aggregated_daily_spend_update_transactions(
updates
)
await self.update_queue.put(aggregated_updates)
async def flush_and_get_aggregated_daily_spend_update_transactions(
self,
) -> Dict[str, BaseDailySpendTransaction]:
"""Get all updates from the queue and return all updates aggregated by daily_transaction_key. Works for both user and team spend updates."""
updates = await self.flush_all_updates_from_in_memory_queue()
aggregated_daily_spend_update_transactions = (
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
updates
)
)
verbose_proxy_logger.debug(
"Aggregated daily spend update transactions: %s",
aggregated_daily_spend_update_transactions,
)
return aggregated_daily_spend_update_transactions
@staticmethod
def get_aggregated_daily_spend_update_transactions(
updates: List[Dict[str, BaseDailySpendTransaction]],
) -> Dict[str, BaseDailySpendTransaction]:
"""Aggregate updates by daily_transaction_key."""
aggregated_daily_spend_update_transactions: Dict[
str, BaseDailySpendTransaction
] = {}
for _update in updates:
for _key, payload in _update.items():
if _key in aggregated_daily_spend_update_transactions:
daily_transaction = aggregated_daily_spend_update_transactions[_key]
daily_transaction["spend"] += payload["spend"]
daily_transaction["prompt_tokens"] += payload["prompt_tokens"]
daily_transaction["completion_tokens"] += payload[
"completion_tokens"
]
daily_transaction["api_requests"] += payload["api_requests"]
daily_transaction["successful_requests"] += payload[
"successful_requests"
]
daily_transaction["failed_requests"] += payload["failed_requests"]
# Add optional metrics cache_read_input_tokens and cache_creation_input_tokens
daily_transaction["cache_read_input_tokens"] = (
payload.get("cache_read_input_tokens", 0) or 0
) + daily_transaction.get("cache_read_input_tokens", 0)
daily_transaction["cache_creation_input_tokens"] = (
payload.get("cache_creation_input_tokens", 0) or 0
) + daily_transaction.get("cache_creation_input_tokens", 0)
else:
aggregated_daily_spend_update_transactions[_key] = deepcopy(payload)
return aggregated_daily_spend_update_transactions
async def _emit_new_item_added_to_queue_event(
self,
queue_size: Optional[int] = None,
):
asyncio.create_task(
service_logger_obj.async_service_success_hook(
service=ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
duration=0,
call_type="_emit_new_item_added_to_queue_event",
event_metadata={
"gauge_labels": ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
"gauge_value": queue_size,
},
)
)
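To make the aggregation behaviour concrete, a short sketch of merging two updates that share the same daily transaction key; the key format and token counts are illustrative, and plain dicts stand in for `BaseDailySpendTransaction`.

```
updates = [
    {"user1_2025-04-25_key1_gpt-4o_openai": {
        "spend": 10.0, "prompt_tokens": 100, "completion_tokens": 50,
        "api_requests": 1, "successful_requests": 1, "failed_requests": 0}},
    {"user1_2025-04-25_key1_gpt-4o_openai": {
        "spend": 5.0, "prompt_tokens": 40, "completion_tokens": 20,
        "api_requests": 1, "successful_requests": 0, "failed_requests": 1}},
]
merged = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(updates)
# merged["user1_2025-04-25_key1_gpt-4o_openai"] now holds spend=15.0,
# api_requests=2, successful_requests=1, failed_requests=1
```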

View File

@@ -0,0 +1,173 @@
import asyncio
import uuid
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.caching.redis_cache import RedisCache
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
ProxyLogging = Any
else:
ProxyLogging = Any
class PodLockManager:
"""
Manager for acquiring and releasing locks for cron jobs using Redis.
Ensures that only one pod can run a cron job at a time.
"""
def __init__(self, redis_cache: Optional[RedisCache] = None):
self.pod_id = str(uuid.uuid4())
self.redis_cache = redis_cache
@staticmethod
def get_redis_lock_key(cronjob_id: str) -> str:
return f"cronjob_lock:{cronjob_id}"
async def acquire_lock(
self,
cronjob_id: str,
) -> Optional[bool]:
"""
Attempt to acquire the lock for a specific cron job using Redis.
Uses the SET command with NX and EX options to ensure atomicity.
"""
if self.redis_cache is None:
verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
return None
try:
verbose_proxy_logger.debug(
"Pod %s attempting to acquire Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
# Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
# and with an expiration (EX) to avoid deadlocks.
lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
acquired = await self.redis_cache.async_set_cache(
lock_key,
self.pod_id,
nx=True,
ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
)
if acquired:
verbose_proxy_logger.info(
"Pod %s successfully acquired Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
return True
else:
# Check if the current pod already holds the lock
current_value = await self.redis_cache.async_get_cache(lock_key)
if current_value is not None:
if isinstance(current_value, bytes):
current_value = current_value.decode("utf-8")
if current_value == self.pod_id:
verbose_proxy_logger.info(
"Pod %s already holds the Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
self._emit_acquired_lock_event(cronjob_id, self.pod_id)
return True
return False
except Exception as e:
verbose_proxy_logger.error(
f"Error acquiring Redis lock for {cronjob_id}: {e}"
)
return False
async def release_lock(
self,
cronjob_id: str,
):
"""
Release the lock if the current pod holds it.
Uses get and delete commands to ensure that only the owner can release the lock.
"""
if self.redis_cache is None:
verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
return
try:
verbose_proxy_logger.debug(
"Pod %s attempting to release Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
current_value = await self.redis_cache.async_get_cache(lock_key)
if current_value is not None:
if isinstance(current_value, bytes):
current_value = current_value.decode("utf-8")
if current_value == self.pod_id:
result = await self.redis_cache.async_delete_cache(lock_key)
if result == 1:
verbose_proxy_logger.info(
"Pod %s successfully released Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
self._emit_released_lock_event(
cronjob_id=cronjob_id,
pod_id=self.pod_id,
)
else:
verbose_proxy_logger.debug(
"Pod %s failed to release Redis lock for cronjob_id=%s",
self.pod_id,
cronjob_id,
)
else:
verbose_proxy_logger.debug(
"Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
self.pod_id,
cronjob_id,
current_value,
)
else:
verbose_proxy_logger.debug(
"Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
self.pod_id,
cronjob_id,
)
except Exception as e:
verbose_proxy_logger.error(
f"Error releasing Redis lock for {cronjob_id}: {e}"
)
@staticmethod
def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
asyncio.create_task(
service_logger_obj.async_service_success_hook(
service=ServiceTypes.POD_LOCK_MANAGER,
duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
call_type="_emit_acquired_lock_event",
event_metadata={
"gauge_labels": f"{cronjob_id}:{pod_id}",
"gauge_value": 1,
},
)
)
@staticmethod
def _emit_released_lock_event(cronjob_id: str, pod_id: str):
asyncio.create_task(
service_logger_obj.async_service_success_hook(
service=ServiceTypes.POD_LOCK_MANAGER,
duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
call_type="_emit_released_lock_event",
event_metadata={
"gauge_labels": f"{cronjob_id}:{pod_id}",
"gauge_value": 0,
},
)
)
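A hedged sketch of how a cron job could guard itself with this manager; the `cronjob_id` value and the job body are assumptions for illustration.

```
async def run_once_across_pods(redis_cache: RedisCache) -> None:
    lock_manager = PodLockManager(redis_cache=redis_cache)
    if await lock_manager.acquire_lock(cronjob_id="reset_budget"):
        try:
            ...  # only the pod holding the lock runs the job body
        finally:
            await lock_manager.release_lock(cronjob_id="reset_budget")
```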

View File

@@ -0,0 +1,405 @@
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
This is to prevent deadlocks and improve reliability
"""
import asyncio
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
from litellm._logging import verbose_proxy_logger
from litellm.caching import RedisCache
from litellm.constants import (
MAX_REDIS_BUFFER_DEQUEUE_COUNT,
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
REDIS_UPDATE_BUFFER_KEY,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import (
DailyTeamSpendTransaction,
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
from litellm.secret_managers.main import str_to_bool
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
else:
PrismaClient = Any
class RedisUpdateBuffer:
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
This is to prevent deadlocks and improve reliability
"""
def __init__(
self,
redis_cache: Optional[RedisCache] = None,
):
self.redis_cache = redis_cache
@staticmethod
def _should_commit_spend_updates_to_redis() -> bool:
"""
Checks if the Pod should commit spend updates to Redis
This setting enables buffering database transactions in Redis
to improve reliability and reduce database contention
"""
from litellm.proxy.proxy_server import general_settings
_use_redis_transaction_buffer: Optional[
Union[bool, str]
] = general_settings.get("use_redis_transaction_buffer", False)
if isinstance(_use_redis_transaction_buffer, str):
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
if _use_redis_transaction_buffer is None:
return False
return _use_redis_transaction_buffer
async def _store_transactions_in_redis(
self,
transactions: Any,
redis_key: str,
service_type: ServiceTypes,
) -> None:
"""
Helper method to store transactions in Redis and emit an event
Args:
transactions: The transactions to store
redis_key: The Redis key to store under
service_type: The service type for event emission
"""
if transactions is None or len(transactions) == 0:
return
list_of_transactions = [safe_dumps(transactions)]
if self.redis_cache is None:
return
current_redis_buffer_size = await self.redis_cache.async_rpush(
key=redis_key,
values=list_of_transactions,
)
await self._emit_new_item_added_to_redis_buffer_event(
queue_size=current_redis_buffer_size,
service=service_type,
)
async def store_in_memory_spend_updates_in_redis(
self,
spend_update_queue: SpendUpdateQueue,
daily_spend_update_queue: DailySpendUpdateQueue,
daily_team_spend_update_queue: DailySpendUpdateQueue,
daily_tag_spend_update_queue: DailySpendUpdateQueue,
):
"""
Stores the in-memory spend updates to Redis
Stores the following in memory data structures in Redis:
- SpendUpdateQueue - Key, User, Team, TeamMember, Org, EndUser Spend updates
- DailySpendUpdateQueue - Daily Spend updates Aggregate view
For SpendUpdateQueue:
Each transaction is a dict stored as follows:
- key is the entity id
- value is the spend amount
```
Redis List:
key_list_transactions:
[
"0929880201": 1.2,
"0929880202": 0.01,
"0929880203": 0.001,
]
```
For DailySpendUpdateQueue:
Each transaction is a Dict[str, DailyUserSpendTransaction] stored as follows:
- key is the daily_transaction_key
- value is the DailyUserSpendTransaction
```
Redis List:
daily_spend_update_transactions:
[
{
"user_keyhash_1_model_1": {
"spend": 1.2,
"prompt_tokens": 1000,
"completion_tokens": 1000,
"api_requests": 1000,
"successful_requests": 1000,
},
}
]
```
"""
if self.redis_cache is None:
verbose_proxy_logger.debug(
"redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
)
return
# Get all transactions
db_spend_update_transactions = (
await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
)
daily_spend_update_transactions = (
await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
daily_team_spend_update_transactions = (
await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
daily_tag_spend_update_transactions = (
await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
verbose_proxy_logger.debug(
"ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
)
verbose_proxy_logger.debug(
"ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
)
# only store in redis if there are any updates to commit
if (
self._number_of_transactions_to_store_in_redis(db_spend_update_transactions)
== 0
):
return
# Store all transaction types using the helper method
await self._store_transactions_in_redis(
transactions=db_spend_update_transactions,
redis_key=REDIS_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
)
await self._store_transactions_in_redis(
transactions=daily_spend_update_transactions,
redis_key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
)
await self._store_transactions_in_redis(
transactions=daily_team_spend_update_transactions,
redis_key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
)
await self._store_transactions_in_redis(
transactions=daily_tag_spend_update_transactions,
redis_key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE,
)
@staticmethod
def _number_of_transactions_to_store_in_redis(
db_spend_update_transactions: DBSpendUpdateTransactions,
) -> int:
"""
Gets the number of transactions to store in Redis
"""
num_transactions = 0
for v in db_spend_update_transactions.values():
if isinstance(v, dict):
num_transactions += len(v)
return num_transactions
@staticmethod
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
"""
Removes the specified prefix from the keys of a dictionary.
"""
return {key.replace(prefix, "", 1): value for key, value in data.items()}
async def get_all_update_transactions_from_redis_buffer(
self,
) -> Optional[DBSpendUpdateTransactions]:
"""
Gets all the update transactions from Redis
On Redis we store a list of transactions as a JSON string
eg.
[
DBSpendUpdateTransactions(
user_list_transactions={
"user_id_1": 1.2,
"user_id_2": 0.01,
},
end_user_list_transactions={},
key_list_transactions={
"0929880201": 1.2,
"0929880202": 0.01,
},
team_list_transactions={},
team_member_list_transactions={},
org_list_transactions={},
),
DBSpendUpdateTransactions(
user_list_transactions={
"user_id_3": 1.2,
"user_id_4": 0.01,
},
end_user_list_transactions={},
key_list_transactions={
"key_id_1": 1.2,
"key_id_2": 0.01,
},
team_list_transactions={},
team_member_list_transactions={},
org_list_transactions={},
),
]
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
# Parse the list of transactions from JSON strings
parsed_transactions = self._parse_list_of_transactions(list_of_transactions)
# If there are no transactions, return None
if len(parsed_transactions) == 0:
return None
# Combine all transactions into a single transaction
combined_transaction = self._combine_list_of_transactions(parsed_transactions)
return combined_transaction
async def get_all_daily_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyUserSpendTransaction]]:
"""
Gets all the daily spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyUserSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
async def get_all_daily_team_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyTeamSpendTransaction]]:
"""
Gets all the daily team spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyTeamSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
@staticmethod
def _parse_list_of_transactions(
list_of_transactions: Union[Any, List[Any]],
) -> List[DBSpendUpdateTransactions]:
"""
Parses the list of transactions from Redis
"""
if isinstance(list_of_transactions, list):
return [json.loads(transaction) for transaction in list_of_transactions]
else:
return [json.loads(list_of_transactions)]
@staticmethod
def _combine_list_of_transactions(
list_of_transactions: List[DBSpendUpdateTransactions],
) -> DBSpendUpdateTransactions:
"""
Combines the list of transactions into a single DBSpendUpdateTransactions object
"""
# Initialize a new combined transaction object with empty dictionaries
combined_transaction = DBSpendUpdateTransactions(
user_list_transactions={},
end_user_list_transactions={},
key_list_transactions={},
team_list_transactions={},
team_member_list_transactions={},
org_list_transactions={},
)
# Define the transaction fields to process
transaction_fields = [
"user_list_transactions",
"end_user_list_transactions",
"key_list_transactions",
"team_list_transactions",
"team_member_list_transactions",
"org_list_transactions",
]
# Loop through each transaction and combine the values
for transaction in list_of_transactions:
# Process each field type
for field in transaction_fields:
if transaction.get(field):
for entity_id, amount in transaction[field].items(): # type: ignore
combined_transaction[field][entity_id] = ( # type: ignore
combined_transaction[field].get(entity_id, 0) + amount # type: ignore
)
return combined_transaction
async def _emit_new_item_added_to_redis_buffer_event(
self,
service: ServiceTypes,
queue_size: int,
):
asyncio.create_task(
service_logger_obj.async_service_success_hook(
service=service,
duration=0,
call_type="_emit_new_item_added_to_queue_event",
event_metadata={
"gauge_labels": service,
"gauge_value": queue_size,
},
)
)
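A hedged sketch of one buffering cycle; in the running proxy the queues are long-lived objects, so constructing them inline here is purely illustrative.

```
async def flush_cycle(redis_cache: RedisCache) -> None:
    buffer = RedisUpdateBuffer(redis_cache=redis_cache)
    if not RedisUpdateBuffer._should_commit_spend_updates_to_redis():
        return
    # push the in-memory queues into the Redis buffer
    await buffer.store_in_memory_spend_updates_in_redis(
        spend_update_queue=SpendUpdateQueue(),
        daily_spend_update_queue=DailySpendUpdateQueue(),
        daily_team_spend_update_queue=DailySpendUpdateQueue(),
        daily_tag_spend_update_queue=DailySpendUpdateQueue(),
    )
    # later (possibly on another pod), drain the buffer before writing to the DB
    combined = await buffer.get_all_update_transactions_from_redis_buffer()
    print(combined)
```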

View File

@@ -0,0 +1,225 @@
import asyncio
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
DBSpendUpdateTransactions,
Litellm_EntityType,
SpendUpdateQueueItem,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
BaseUpdateQueue,
service_logger_obj,
)
from litellm.types.services import ServiceTypes
class SpendUpdateQueue(BaseUpdateQueue):
"""
In memory buffer for spend updates that should be committed to the database
"""
def __init__(self):
super().__init__()
self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()
async def flush_and_get_aggregated_db_spend_update_transactions(
self,
) -> DBSpendUpdateTransactions:
"""Flush all updates from the queue and return all updates aggregated by entity type."""
updates = await self.flush_all_updates_from_in_memory_queue()
verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
return self.get_aggregated_db_spend_update_transactions(updates)
async def add_update(self, update: SpendUpdateQueueItem):
"""Enqueue an update to the spend update queue"""
verbose_proxy_logger.debug("Adding update to queue: %s", update)
await self.update_queue.put(update)
# if the queue is full, aggregate the updates
if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
verbose_proxy_logger.warning(
"Spend update queue is full. Aggregating all entries in queue to concatenate entries."
)
await self.aggregate_queue_updates()
async def aggregate_queue_updates(self):
"""Concatenate all updates in the queue to reduce the size of in-memory queue"""
updates: List[
SpendUpdateQueueItem
] = await self.flush_all_updates_from_in_memory_queue()
aggregated_updates = self._get_aggregated_spend_update_queue_item(updates)
for update in aggregated_updates:
await self.update_queue.put(update)
return
def _get_aggregated_spend_update_queue_item(
self, updates: List[SpendUpdateQueueItem]
) -> List[SpendUpdateQueueItem]:
"""
This is used to reduce the size of the in-memory queue by aggregating updates by entity type + id
Aggregate updates by entity type + id
eg.
```
[
{
"entity_type": "user",
"entity_id": "123",
"response_cost": 100
},
{
"entity_type": "user",
"entity_id": "123",
"response_cost": 200
}
]
```
becomes
```
[
{
"entity_type": "user",
"entity_id": "123",
"response_cost": 300
}
]
```
"""
verbose_proxy_logger.debug(
"Aggregating spend updates, current queue size: %s",
self.update_queue.qsize(),
)
aggregated_spend_updates: List[SpendUpdateQueueItem] = []
_in_memory_map: Dict[str, SpendUpdateQueueItem] = {}
"""
Used for combining several updates into a single update
Key=entity_type:entity_id
Value=SpendUpdateQueueItem
"""
for update in updates:
_key = f"{update.get('entity_type')}:{update.get('entity_id')}"
if _key not in _in_memory_map:
_in_memory_map[_key] = update
else:
current_cost = _in_memory_map[_key].get("response_cost", 0) or 0
update_cost = update.get("response_cost", 0) or 0
_in_memory_map[_key]["response_cost"] = current_cost + update_cost
for _key, update in _in_memory_map.items():
aggregated_spend_updates.append(update)
verbose_proxy_logger.debug(
"Aggregated spend updates: %s", aggregated_spend_updates
)
return aggregated_spend_updates
def get_aggregated_db_spend_update_transactions(
self, updates: List[SpendUpdateQueueItem]
) -> DBSpendUpdateTransactions:
"""Aggregate updates by entity type."""
# Initialize all transaction lists as empty dicts
db_spend_update_transactions = DBSpendUpdateTransactions(
user_list_transactions={},
end_user_list_transactions={},
key_list_transactions={},
team_list_transactions={},
team_member_list_transactions={},
org_list_transactions={},
)
# Map entity types to their corresponding transaction dictionary keys
entity_type_to_dict_key = {
Litellm_EntityType.USER: "user_list_transactions",
Litellm_EntityType.END_USER: "end_user_list_transactions",
Litellm_EntityType.KEY: "key_list_transactions",
Litellm_EntityType.TEAM: "team_list_transactions",
Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
Litellm_EntityType.ORGANIZATION: "org_list_transactions",
}
for update in updates:
entity_type = update.get("entity_type")
entity_id = update.get("entity_id") or ""
response_cost = update.get("response_cost") or 0
if entity_type is None:
verbose_proxy_logger.debug(
"Skipping update spend for update: %s, because entity_type is None",
update,
)
continue
dict_key = entity_type_to_dict_key.get(entity_type)
if dict_key is None:
verbose_proxy_logger.debug(
"Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
update,
)
continue # Skip unknown entity types
# Type-safe access using if/elif statements
if dict_key == "user_list_transactions":
transactions_dict = db_spend_update_transactions[
"user_list_transactions"
]
elif dict_key == "end_user_list_transactions":
transactions_dict = db_spend_update_transactions[
"end_user_list_transactions"
]
elif dict_key == "key_list_transactions":
transactions_dict = db_spend_update_transactions[
"key_list_transactions"
]
elif dict_key == "team_list_transactions":
transactions_dict = db_spend_update_transactions[
"team_list_transactions"
]
elif dict_key == "team_member_list_transactions":
transactions_dict = db_spend_update_transactions[
"team_member_list_transactions"
]
elif dict_key == "org_list_transactions":
transactions_dict = db_spend_update_transactions[
"org_list_transactions"
]
else:
continue
if transactions_dict is None:
transactions_dict = {}
# dict_key is guaranteed to be one of the DBSpendUpdateTransactions keys handled above, so indexing is safe
db_spend_update_transactions[dict_key] = transactions_dict # type: ignore
if entity_id not in transactions_dict:
transactions_dict[entity_id] = 0
transactions_dict[entity_id] += response_cost or 0
return db_spend_update_transactions
async def _emit_new_item_added_to_queue_event(
self,
queue_size: Optional[int] = None,
):
asyncio.create_task(
service_logger_obj.async_service_success_hook(
service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
duration=0,
call_type="_emit_new_item_added_to_queue_event",
event_metadata={
"gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
"gauge_value": queue_size,
},
)
)
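A small sketch of the entity-level aggregation; the entity id and costs are illustrative, and the plain dicts stand in for `SpendUpdateQueueItem`.

```
async def demo() -> None:
    queue = SpendUpdateQueue()
    await queue.add_update(
        {"entity_type": Litellm_EntityType.USER, "entity_id": "user_123", "response_cost": 0.25}
    )
    await queue.add_update(
        {"entity_type": Litellm_EntityType.USER, "entity_id": "user_123", "response_cost": 0.75}
    )
    transactions = await queue.flush_and_get_aggregated_db_spend_update_transactions()
    print(transactions["user_list_transactions"])  # {'user_123': 1.0}

asyncio.run(demo())
```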

View File

@@ -0,0 +1,71 @@
"""
Deprecated. Only PostgreSQL is supported.
"""
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import DynamoDBArgs
from litellm.proxy.db.base_client import CustomDB
class DynamoDBWrapper(CustomDB):
from aiodynamo.credentials import Credentials, StaticCredentials
credentials: Credentials
def __init__(self, database_arguments: DynamoDBArgs):
from aiodynamo.models import PayPerRequest, Throughput
self.throughput_type = None
if database_arguments.billing_mode == "PAY_PER_REQUEST":
self.throughput_type = PayPerRequest()
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
if (
database_arguments.read_capacity_units is not None
and isinstance(database_arguments.read_capacity_units, int)
and database_arguments.write_capacity_units is not None
and isinstance(database_arguments.write_capacity_units, int)
):
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
else:
raise Exception(
f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
)
self.database_arguments = database_arguments
self.region_name = database_arguments.region_name
def set_env_vars_based_on_arn(self):
if self.database_arguments.aws_role_name is None:
return
verbose_proxy_logger.debug(
f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
)
import os
import boto3
sts_client = boto3.client("sts")
# call 1
sts_client.assume_role_with_web_identity(
RoleArn=self.database_arguments.aws_role_name,
RoleSessionName=self.database_arguments.aws_session_name,
WebIdentityToken=self.database_arguments.aws_web_identity_token,
)
# call 2
assumed_role = sts_client.assume_role(
RoleArn=self.database_arguments.assume_role_aws_role_name,
RoleSessionName=self.database_arguments.assume_role_aws_session_name,
)
aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"]
aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"]
aws_session_token = assumed_role["Credentials"]["SessionToken"]
verbose_proxy_logger.debug(
f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}"
)
# set these in the env so aiodynamo can use them
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
os.environ["AWS_SESSION_TOKEN"] = aws_session_token

View File

@@ -0,0 +1,61 @@
from typing import Union
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
ProxyErrorTypes,
ProxyException,
)
from litellm.secret_managers.main import str_to_bool
class PrismaDBExceptionHandler:
"""
Class to handle DB Exceptions or Connection Errors
"""
@staticmethod
def should_allow_request_on_db_unavailable() -> bool:
"""
Returns True if the request should be allowed to proceed despite the DB connection error
"""
from litellm.proxy.proxy_server import general_settings
_allow_requests_on_db_unavailable: Union[bool, str] = general_settings.get(
"allow_requests_on_db_unavailable", False
)
if isinstance(_allow_requests_on_db_unavailable, bool):
return _allow_requests_on_db_unavailable
if str_to_bool(_allow_requests_on_db_unavailable) is True:
return True
return False
@staticmethod
def is_database_connection_error(e: Exception) -> bool:
"""
Returns True if the exception is from a database outage / connection error
"""
import prisma
if isinstance(e, DB_CONNECTION_ERROR_TYPES):
return True
if isinstance(e, prisma.errors.PrismaError):
return True
if isinstance(e, ProxyException) and e.type == ProxyErrorTypes.no_db_connection:
return True
return False
@staticmethod
def handle_db_exception(e: Exception):
"""
Primary handler for `allow_requests_on_db_unavailable` flag. Decides whether to raise a DB Exception or not based on the flag.
- If exception is a DB Connection Error, and `allow_requests_on_db_unavailable` is True,
- Do not raise an exception, return None
- Else, raise the exception
"""
if (
PrismaDBExceptionHandler.is_database_connection_error(e)
and PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
):
return None
raise e
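A hedged sketch of the intended call pattern; `fetch_user_row` and the Prisma model accessor used here are assumptions for illustration.

```
async def fetch_user_row(prisma_client, user_id: str):
    try:
        return await prisma_client.db.litellm_usertable.find_unique(  # accessor name is assumed
            where={"user_id": user_id}
        )
    except Exception as e:
        # Returns None (swallows the error) only when this is a DB connection
        # issue and `allow_requests_on_db_unavailable` is enabled; otherwise re-raises.
        PrismaDBExceptionHandler.handle_db_exception(e)
        return None
```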

View File

@@ -0,0 +1,142 @@
"""
Handles logging DB success/failure to ServiceLogger()
ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
"""
import asyncio
from datetime import datetime
from functools import wraps
from typing import Callable, Dict, Tuple
from litellm._service_logger import ServiceTypes
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
def log_db_metrics(func):
"""
Decorator to log the duration of a DB related function to ServiceLogger()
Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
On failure, it checks whether the exception is a PrismaError, httpx.ConnectError, or httpx.TimeoutException, and only then logs it as a DB service failure.
Args:
func: The function to be decorated
Returns:
Result from the decorated function
Raises:
Exception: If the decorated function raises an exception
"""
@wraps(func)
async def wrapper(*args, **kwargs):
start_time: datetime = datetime.now()
try:
result = await func(*args, **kwargs)
end_time: datetime = datetime.now()
from litellm.proxy.proxy_server import proxy_logging_obj
if "PROXY" not in func.__name__:
asyncio.create_task(
proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs.get("parent_otel_span", None),
duration=(end_time - start_time).total_seconds(),
start_time=start_time,
end_time=end_time,
event_metadata={
"function_name": func.__name__,
"function_kwargs": kwargs,
"function_args": args,
},
)
)
elif (
# in litellm custom callbacks kwargs is passed as arg[0]
# https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
args is not None
and len(args) > 1
and isinstance(args[1], dict)
):
passed_kwargs = args[1]
parent_otel_span = _get_parent_otel_span_from_kwargs(
kwargs=passed_kwargs
)
if parent_otel_span is not None:
metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
asyncio.create_task(
proxy_logging_obj.service_logging_obj.async_service_success_hook(
service=ServiceTypes.BATCH_WRITE_TO_DB,
call_type=func.__name__,
parent_otel_span=parent_otel_span,
duration=0.0,
start_time=start_time,
end_time=end_time,
event_metadata=metadata,
)
)
# end of logging to otel
return result
except Exception as e:
end_time: datetime = datetime.now()
await _handle_logging_db_exception(
e=e,
func=func,
kwargs=kwargs,
args=args,
start_time=start_time,
end_time=end_time,
)
raise e
return wrapper
def _is_exception_related_to_db(e: Exception) -> bool:
"""
Returns True if the exception is related to the DB
"""
import httpx
from prisma.errors import PrismaError
return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException))
async def _handle_logging_db_exception(
e: Exception,
func: Callable,
kwargs: Dict,
args: Tuple,
start_time: datetime,
end_time: datetime,
) -> None:
from litellm.proxy.proxy_server import proxy_logging_obj
# don't log this as a DB service failure if the exception is not DB-related
if _is_exception_related_to_db(e) is not True:
return
await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
error=e,
service=ServiceTypes.DB,
call_type=func.__name__,
parent_otel_span=kwargs.get("parent_otel_span"),
duration=(end_time - start_time).total_seconds(),
start_time=start_time,
end_time=end_time,
event_metadata={
"function_name": func.__name__,
"function_kwargs": kwargs,
"function_args": args,
},
)
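A hedged sketch of decorating an async DB helper; the function body and the `prisma_client` attribute access are illustrative.

```
@log_db_metrics
async def update_user_spend(prisma_client, user_id: str, response_cost: float) -> None:
    # The decorator times this call and reports success/failure to the service
    # logger; only DB-related exceptions are logged as DB service failures.
    await prisma_client.db.query_raw("SELECT 1")  # stand-in for a real spend update
```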

View File

@@ -0,0 +1,209 @@
"""
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
"""
import asyncio
import os
import random
import subprocess
import time
import urllib
import urllib.parse
from datetime import datetime, timedelta
from typing import Any, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.secret_managers.main import str_to_bool
class PrismaWrapper:
def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
self._original_prisma = original_prisma
self.iam_token_db_auth = iam_token_db_auth
def is_token_expired(self, token_url: Optional[str]) -> bool:
if token_url is None:
return True
# Decode the token URL to handle URL-encoded characters
decoded_url = urllib.parse.unquote(token_url)
# Parse the token URL
parsed_url = urllib.parse.urlparse(decoded_url)
# Parse the query parameters from the path component (if they exist there)
query_params = urllib.parse.parse_qs(parsed_url.query)
# Get expiration time from the query parameters
expires = query_params.get("X-Amz-Expires", [None])[0]
if expires is None:
raise ValueError("X-Amz-Expires parameter is missing or invalid.")
expires_int = int(expires)
# Get the token's creation time from the X-Amz-Date parameter
token_time_str = query_params.get("X-Amz-Date", [""])[0]
if not token_time_str:
raise ValueError("X-Amz-Date parameter is missing or invalid.")
# Ensure the token time string is parsed correctly
try:
token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
except ValueError as e:
raise ValueError(f"Invalid X-Amz-Date format: {e}")
# Calculate the expiration time
expiration_time = token_time + timedelta(seconds=expires_int)
# Current time in UTC
current_time = datetime.utcnow()
# Check if the token is expired
return current_time > expiration_time
def get_rds_iam_token(self) -> Optional[str]:
if self.iam_token_db_auth:
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
db_host = os.getenv("DATABASE_HOST")
db_port = os.getenv("DATABASE_PORT")
db_user = os.getenv("DATABASE_USER")
db_name = os.getenv("DATABASE_NAME")
db_schema = os.getenv("DATABASE_SCHEMA")
token = generate_iam_auth_token(
db_host=db_host, db_port=db_port, db_user=db_user
)
# print(f"token: {token}")
_db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
if db_schema:
_db_url += f"?schema={db_schema}"
os.environ["DATABASE_URL"] = _db_url
return _db_url
return None
async def recreate_prisma_client(
self, new_db_url: str, http_client: Optional[Any] = None
):
from prisma import Prisma # type: ignore
if http_client is not None:
self._original_prisma = Prisma(http=http_client)
else:
self._original_prisma = Prisma()
await self._original_prisma.connect()
def __getattr__(self, name: str):
original_attr = getattr(self._original_prisma, name)
if self.iam_token_db_auth:
db_url = os.getenv("DATABASE_URL")
if self.is_token_expired(db_url):
db_url = self.get_rds_iam_token()
loop = asyncio.get_event_loop()
if db_url:
if loop.is_running():
asyncio.run_coroutine_threadsafe(
self.recreate_prisma_client(db_url), loop
)
else:
asyncio.run(self.recreate_prisma_client(db_url))
else:
raise ValueError("Failed to get RDS IAM token")
return original_attr
class PrismaManager:
@staticmethod
def _get_prisma_dir() -> str:
"""Get the path to the migrations directory"""
abspath = os.path.abspath(__file__)
dname = os.path.dirname(os.path.dirname(abspath))
return dname
@staticmethod
def setup_database(use_migrate: bool = False) -> bool:
"""
Set up the database using either prisma migrate or prisma db push
Returns:
bool: True if setup was successful, False otherwise
"""
use_migrate = str_to_bool(os.getenv("USE_PRISMA_MIGRATE")) or use_migrate
for attempt in range(4):
original_dir = os.getcwd()
prisma_dir = PrismaManager._get_prisma_dir()
schema_path = prisma_dir + "/schema.prisma"
os.chdir(prisma_dir)
try:
if use_migrate:
try:
from litellm_proxy_extras.utils import ProxyExtrasDBManager
except ImportError as e:
verbose_proxy_logger.error(
f"\033[1;31mLiteLLM: Failed to import proxy extras. Got {e}\033[0m"
)
return False
return ProxyExtrasDBManager.setup_database(
schema_path=schema_path, use_migrate=use_migrate
)
else:
# Use prisma db push with increased timeout
subprocess.run(
["prisma", "db", "push", "--accept-data-loss"],
timeout=60,
check=True,
)
return True
except subprocess.TimeoutExpired:
verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
time.sleep(random.randrange(5, 15))
except subprocess.CalledProcessError as e:
attempts_left = 3 - attempt
retry_msg = (
f" Retrying... ({attempts_left} attempts left)"
if attempts_left > 0
else ""
)
verbose_proxy_logger.warning(
f"The process failed to execute. Details: {e}.{retry_msg}"
)
time.sleep(random.randrange(5, 15))
finally:
os.chdir(original_dir)
return False
def should_update_prisma_schema(
disable_updates: Optional[Union[bool, str]] = None
) -> bool:
"""
Determines if Prisma Schema updates should be applied during startup.
Args:
disable_updates: Controls whether schema updates are disabled.
Accepts boolean or string ('true'/'false'). Defaults to checking DISABLE_SCHEMA_UPDATE env var.
Returns:
bool: True if schema updates should be applied, False if updates are disabled.
Examples:
>>> should_update_prisma_schema() # Checks DISABLE_SCHEMA_UPDATE env var
>>> should_update_prisma_schema(True) # Explicitly disable updates
>>> should_update_prisma_schema("false") # Enable updates using string
"""
if disable_updates is None:
disable_updates = os.getenv("DISABLE_SCHEMA_UPDATE", "false")
if isinstance(disable_updates, str):
disable_updates = str_to_bool(disable_updates)
return not bool(disable_updates)
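A hedged sketch of how these helpers fit together at startup; whether the proxy calls them in exactly this order is an assumption made for the example.

```
if should_update_prisma_schema():
    if not PrismaManager.setup_database(use_migrate=False):
        verbose_proxy_logger.error("Database setup failed after all retry attempts")
```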