structure saas with tools
@@ -0,0 +1,53 @@
from typing import Any, Literal, List


class CustomDB:
    """
    Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
    """

    def __init__(self) -> None:
        pass

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Check if key is valid
        """
        pass

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        """
        For new key / user logic
        """
        pass

    def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For cost tracking logic
        """
        pass

    def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        For /key/delete endpoints
        """
        pass

    def connect(
        self,
    ):
        """
        For connecting to db and creating / updating any tables
        """
        pass

    def disconnect(
        self,
    ):
        """
        For closing connection on server shutdown
        """
        pass
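For illustration, a minimal sketch of a concrete backend that follows this interface; the `InMemoryDB` class, its dict-backed tables, and the assumption that inserted values carry an `"id"` field are not part of this commit.

```
from typing import Any, List, Literal

from litellm.proxy.db.base_client import CustomDB


class InMemoryDB(CustomDB):
    """Hypothetical example backend -- illustration only."""

    def __init__(self) -> None:
        # one dict per logical table
        self.tables: dict = {"user": {}, "key": {}, "config": {}}

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        return self.tables[table_name].get(key)

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        # assumes `value` carries its own identifier under "id"
        self.tables[table_name][value["id"]] = value

    def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        self.tables[table_name][key] = value

    def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
        for k in keys:
            self.tables[table_name].pop(k, None)
```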
@@ -0,0 +1,104 @@
"""Module for checking differences between Prisma schema and database."""

import os
import subprocess
from typing import List, Optional, Tuple

from litellm._logging import verbose_logger


def extract_sql_commands(diff_output: str) -> List[str]:
    """
    Extract SQL commands from the Prisma migrate diff output.

    Args:
        diff_output (str): The full output from prisma migrate diff.

    Returns:
        List[str]: A list of SQL commands extracted from the diff output.
    """
    # Split the output into lines and remove empty lines
    lines = [line.strip() for line in diff_output.split("\n") if line.strip()]

    sql_commands = []
    current_command = ""
    in_sql_block = False

    for line in lines:
        if line.startswith("-- "):  # Comment line, likely a table operation description
            if in_sql_block and current_command:
                sql_commands.append(current_command.strip())
                current_command = ""
            in_sql_block = True
        elif in_sql_block:
            if line.endswith(";"):
                current_command += line
                sql_commands.append(current_command.strip())
                current_command = ""
                in_sql_block = False
            else:
                current_command += line + " "

    # Add any remaining command
    if current_command:
        sql_commands.append(current_command.strip())

    return sql_commands


def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]:
    """Checks for differences between the current database and the Prisma schema.

    Returns:
        A tuple containing:
        - A boolean indicating if differences were found (True) or not (False).
        - A list of SQL commands extracted from the diff output (empty if none).

    Raises:
        subprocess.CalledProcessError: If the Prisma command fails.
        Exception: For any other errors during execution.
    """
    verbose_logger.debug("Checking for Prisma schema diff...")  # noqa: T201
    try:
        result = subprocess.run(
            [
                "prisma",
                "migrate",
                "diff",
                "--from-url",
                db_url,
                "--to-schema-datamodel",
                "./schema.prisma",
                "--script",
            ],
            capture_output=True,
            text=True,
            check=True,
        )

        # return True, "Migration diff generated successfully."
        sql_commands = extract_sql_commands(result.stdout)

        if sql_commands:
            print("Changes to DB Schema detected")  # noqa: T201
            print("Required SQL commands:")  # noqa: T201
            for command in sql_commands:
                print(command)  # noqa: T201
            return True, sql_commands
        else:
            return False, []
    except subprocess.CalledProcessError as e:
        error_message = f"Failed to generate migration diff. Error: {e.stderr}"
        print(error_message)  # noqa: T201
        return False, []


def check_prisma_schema_diff(db_url: Optional[str] = None) -> None:
    """Main function to run the Prisma schema diff check."""
    if db_url is None:
        db_url = os.getenv("DATABASE_URL")
        if db_url is None:
            raise Exception("DATABASE_URL not set")
    has_diff, message = check_prisma_schema_diff_helper(db_url)
    if has_diff:
        verbose_logger.exception(
            "🚨🚨🚨 prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format(
                message
            )
        )
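A small usage sketch of `extract_sql_commands`, with the function above in scope; the sample `prisma migrate diff` output below is invented for illustration.

```
sample_diff = """
-- CreateTable
CREATE TABLE "LiteLLM_Example" (
    "id" TEXT NOT NULL
);

-- AlterTable
ALTER TABLE "LiteLLM_Example" ADD COLUMN "spend" DOUBLE PRECISION;
"""

# Two statements come back: the CREATE TABLE (joined onto one line) and the ALTER TABLE.
for command in extract_sql_commands(sample_diff):
    print(command)
```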
@@ -0,0 +1,227 @@
from typing import Any

from litellm import verbose_logger

_db = Any


async def create_missing_views(db: _db):  # noqa: PLR0915
    """
    --------------------------------------------------
    NOTE: Copy of `litellm/db_scripts/create_views.py`.
    --------------------------------------------------

    Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend views exist in the user's db.

    LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth

    MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month

    If a view doesn't exist, it will be created.
    """
    try:
        # Try to select one row from the view
        await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
        print("LiteLLM_VerificationTokenView Exists!")  # noqa
    except Exception:
        # If an error occurs, the view does not exist, so create it
        await db.execute_raw(
            """
            CREATE VIEW "LiteLLM_VerificationTokenView" AS
            SELECT
                v.*,
                t.spend AS team_spend,
                t.max_budget AS team_max_budget,
                t.tpm_limit AS team_tpm_limit,
                t.rpm_limit AS team_rpm_limit
            FROM "LiteLLM_VerificationToken" v
            LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
            """
        )

        print("LiteLLM_VerificationTokenView Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
        print("MonthlyGlobalSpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
        SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            DATE("startTime");
        """
        await db.execute_raw(query=sql_query)

        print("MonthlyGlobalSpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
        print("Last30dKeysBySpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
        SELECT
            L."api_key",
            V."key_alias",
            V."key_name",
            SUM(L."spend") AS total_spend
        FROM
            "LiteLLM_SpendLogs" L
        LEFT JOIN
            "LiteLLM_VerificationToken" V
        ON
            L."api_key" = V."token"
        WHERE
            L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            L."api_key", V."key_alias", V."key_name"
        ORDER BY
            total_spend DESC;
        """
        await db.execute_raw(query=sql_query)

        print("Last30dKeysBySpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
        print("Last30dModelsBySpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
        SELECT
            "model",
            SUM("spend") AS total_spend
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
            AND "model" != ''
        GROUP BY
            "model"
        ORDER BY
            total_spend DESC;
        """
        await db.execute_raw(query=sql_query)

        print("Last30dModelsBySpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
        print("MonthlyGlobalSpendPerKey Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
        SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend,
            api_key as api_key
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            DATE("startTime"),
            api_key;
        """
        await db.execute_raw(query=sql_query)

        print("MonthlyGlobalSpendPerKey Created!")  # noqa

    try:
        await db.query_raw(
            """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
        )
        print("MonthlyGlobalSpendPerUserPerKey Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
        SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend,
            api_key as api_key,
            "user" as "user"
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            DATE("startTime"),
            "user",
            api_key;
        """
        await db.execute_raw(query=sql_query)

        print("MonthlyGlobalSpendPerUserPerKey Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
        print("DailyTagSpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "DailyTagSpend" AS
        SELECT
            jsonb_array_elements_text(request_tags) AS individual_request_tag,
            DATE(s."startTime") AS spend_date,
            COUNT(*) AS log_count,
            SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs" s
        GROUP BY individual_request_tag, DATE(s."startTime");
        """
        await db.execute_raw(query=sql_query)

        print("DailyTagSpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
        print("Last30dTopEndUsersSpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE VIEW "Last30dTopEndUsersSpend" AS
        SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs"
        WHERE end_user <> '' AND end_user <> user
            AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
        GROUP BY end_user
        ORDER BY total_spend DESC
        LIMIT 100;
        """
        await db.execute_raw(query=sql_query)

        print("Last30dTopEndUsersSpend Created!")  # noqa

    return


async def should_create_missing_views(db: _db) -> bool:
    """
    Run only on first time startup.

    If the SpendLogs table already has values, then don't create views on startup.
    """

    sql_query = """
    SELECT reltuples::BIGINT
    FROM pg_class
    WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
    """

    result = await db.query_raw(query=sql_query)

    verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
    if (
        result
        and isinstance(result, list)
        and len(result) > 0
        and isinstance(result[0], dict)
        and "reltuples" in result[0]
        and result[0]["reltuples"]
        and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1)
    ):
        verbose_logger.debug("Should create views")
        return True

    return False
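A hedged sketch of wiring these helpers into proxy startup, with both functions above in scope; the `ensure_views_on_startup` wrapper and the way the Prisma client is obtained are assumptions, not code from this commit.

```
async def ensure_views_on_startup(db):
    # `db` is assumed to be a connected Prisma client exposing query_raw / execute_raw,
    # e.g. the client object the proxy already holds after connecting to the database.
    if await should_create_missing_views(db=db):
        await create_missing_views(db=db)
```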
File diff suppressed because it is too large
@@ -0,0 +1,49 @@
"""
Base class for in memory buffer for database transactions
"""
import asyncio
from typing import Optional

from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging

service_logger_obj = (
    ServiceLogging()
)  # used for tracking metrics for in memory buffer, redis buffer, pod lock manager

from litellm.constants import MAX_IN_MEMORY_QUEUE_FLUSH_COUNT, MAX_SIZE_IN_MEMORY_QUEUE


class BaseUpdateQueue:
    """Base class for in memory buffer for database transactions"""

    def __init__(self):
        self.update_queue = asyncio.Queue()
        self.MAX_SIZE_IN_MEMORY_QUEUE = MAX_SIZE_IN_MEMORY_QUEUE

    async def add_update(self, update):
        """Enqueue an update."""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        await self._emit_new_item_added_to_queue_event(
            queue_size=self.update_queue.qsize()
        )

    async def flush_all_updates_from_in_memory_queue(self):
        """Get all updates from the queue."""
        updates = []
        while not self.update_queue.empty():
            # Circuit breaker to ensure we're not stuck dequeuing updates. Protect CPU utilization
            if len(updates) >= MAX_IN_MEMORY_QUEUE_FLUSH_COUNT:
                verbose_proxy_logger.warning(
                    "Max in memory queue flush count reached, stopping flush"
                )
                break
            updates.append(await self.update_queue.get())
        return updates

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """placeholder, emit event when a new item is added to the queue"""
        pass
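A hedged sketch of the enqueue/flush cycle on `BaseUpdateQueue`; the dict payloads below are invented placeholders, and the import path matches the one used elsewhere in this commit.

```
import asyncio

from litellm.proxy.db.db_transaction_queue.base_update_queue import BaseUpdateQueue


async def main():
    queue = BaseUpdateQueue()
    await queue.add_update({"entity_id": "key-1", "response_cost": 0.05})
    await queue.add_update({"entity_id": "key-2", "response_cost": 0.10})
    updates = await queue.flush_all_updates_from_in_memory_queue()
    print(updates)  # both updates, in FIFO order


asyncio.run(main())
```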
@@ -0,0 +1,149 @@
import asyncio
from copy import deepcopy
from typing import Dict, List, Optional

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import BaseDailySpendTransaction
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
    BaseUpdateQueue,
    service_logger_obj,
)
from litellm.types.services import ServiceTypes


class DailySpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for daily spend updates that should be committed to the database.

    To add a new daily spend update transaction, use the following format:

        daily_spend_update_queue.add_update({
            "user1_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
            }
        })

    The queue contains a list of daily spend update transactions, e.g.:

        queue = [
            {
                "user1_date_api_key_model_custom_llm_provider": {
                    "spend": 10,
                    "prompt_tokens": 100,
                    "completion_tokens": 100,
                    "api_requests": 100,
                    "successful_requests": 100,
                    "failed_requests": 100,
                }
            },
            {
                "user2_date_api_key_model_custom_llm_provider": {
                    "spend": 10,
                    "prompt_tokens": 100,
                    "completion_tokens": 100,
                    "api_requests": 100,
                    "successful_requests": 100,
                    "failed_requests": 100,
                }
            }
        ]
    """

    def __init__(self):
        super().__init__()
        self.update_queue: asyncio.Queue[Dict[str, BaseDailySpendTransaction]] = (
            asyncio.Queue()
        )

    async def add_update(self, update: Dict[str, BaseDailySpendTransaction]):
        """Enqueue an update."""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """
        Combine all updates in the queue into a single update.
        This is used to reduce the size of the in-memory queue.
        """
        updates: List[Dict[str, BaseDailySpendTransaction]] = (
            await self.flush_all_updates_from_in_memory_queue()
        )
        aggregated_updates = self.get_aggregated_daily_spend_update_transactions(
            updates
        )
        await self.update_queue.put(aggregated_updates)

    async def flush_and_get_aggregated_daily_spend_update_transactions(
        self,
    ) -> Dict[str, BaseDailySpendTransaction]:
        """Get all updates from the queue and return them aggregated by daily_transaction_key. Works for both user and team spend updates."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        aggregated_daily_spend_update_transactions = (
            DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                updates
            )
        )
        verbose_proxy_logger.debug(
            "Aggregated daily spend update transactions: %s",
            aggregated_daily_spend_update_transactions,
        )
        return aggregated_daily_spend_update_transactions

    @staticmethod
    def get_aggregated_daily_spend_update_transactions(
        updates: List[Dict[str, BaseDailySpendTransaction]],
    ) -> Dict[str, BaseDailySpendTransaction]:
        """Aggregate updates by daily_transaction_key."""
        aggregated_daily_spend_update_transactions: Dict[
            str, BaseDailySpendTransaction
        ] = {}
        for _update in updates:
            for _key, payload in _update.items():
                if _key in aggregated_daily_spend_update_transactions:
                    daily_transaction = aggregated_daily_spend_update_transactions[_key]
                    daily_transaction["spend"] += payload["spend"]
                    daily_transaction["prompt_tokens"] += payload["prompt_tokens"]
                    daily_transaction["completion_tokens"] += payload[
                        "completion_tokens"
                    ]
                    daily_transaction["api_requests"] += payload["api_requests"]
                    daily_transaction["successful_requests"] += payload[
                        "successful_requests"
                    ]
                    daily_transaction["failed_requests"] += payload["failed_requests"]

                    # Add optional metrics cache_read_input_tokens and cache_creation_input_tokens
                    daily_transaction["cache_read_input_tokens"] = (
                        payload.get("cache_read_input_tokens", 0) or 0
                    ) + daily_transaction.get("cache_read_input_tokens", 0)

                    daily_transaction["cache_creation_input_tokens"] = (
                        payload.get("cache_creation_input_tokens", 0) or 0
                    ) + daily_transaction.get("cache_creation_input_tokens", 0)

                else:
                    aggregated_daily_spend_update_transactions[_key] = deepcopy(payload)
        return aggregated_daily_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
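A hedged illustration of the aggregation performed by `get_aggregated_daily_spend_update_transactions`; the daily transaction key and the numbers are invented, and plain dicts stand in for `BaseDailySpendTransaction` entries.

```
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
    DailySpendUpdateQueue,
)

updates = [
    {"user1_2025-01-01_key1_gpt-4_openai": {
        "spend": 10, "prompt_tokens": 100, "completion_tokens": 50,
        "api_requests": 1, "successful_requests": 1, "failed_requests": 0,
    }},
    {"user1_2025-01-01_key1_gpt-4_openai": {
        "spend": 5, "prompt_tokens": 40, "completion_tokens": 20,
        "api_requests": 1, "successful_requests": 0, "failed_requests": 1,
    }},
]

merged = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(updates)
# Both entries share the same daily_transaction_key, so their counters are summed:
# spend=15, prompt_tokens=140, completion_tokens=70, api_requests=2, ...
print(merged)
```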
@@ -0,0 +1,173 @@
import asyncio
import uuid
from typing import TYPE_CHECKING, Any, Optional

from litellm._logging import verbose_proxy_logger
from litellm.caching.redis_cache import RedisCache
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.types.services import ServiceTypes

if TYPE_CHECKING:
    ProxyLogging = Any
else:
    ProxyLogging = Any


class PodLockManager:
    """
    Manager for acquiring and releasing locks for cron jobs using Redis.

    Ensures that only one pod can run a cron job at a time.
    """

    def __init__(self, redis_cache: Optional[RedisCache] = None):
        self.pod_id = str(uuid.uuid4())
        self.redis_cache = redis_cache

    @staticmethod
    def get_redis_lock_key(cronjob_id: str) -> str:
        return f"cronjob_lock:{cronjob_id}"

    async def acquire_lock(
        self,
        cronjob_id: str,
    ) -> Optional[bool]:
        """
        Attempt to acquire the lock for a specific cron job using Redis.
        Uses the SET command with NX and EX options to ensure atomicity.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
            return None
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
            # and with an expiration (EX) to avoid deadlocks.
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            acquired = await self.redis_cache.async_set_cache(
                lock_key,
                self.pod_id,
                nx=True,
                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            )
            if acquired:
                verbose_proxy_logger.info(
                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )

                return True
            else:
                # Check if the current pod already holds the lock
                current_value = await self.redis_cache.async_get_cache(lock_key)
                if current_value is not None:
                    if isinstance(current_value, bytes):
                        current_value = current_value.decode("utf-8")
                    if current_value == self.pod_id:
                        verbose_proxy_logger.info(
                            "Pod %s already holds the Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                        return True
                return False
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error acquiring Redis lock for {cronjob_id}: {e}"
            )
            return False

    async def release_lock(
        self,
        cronjob_id: str,
    ):
        """
        Release the lock if the current pod holds it.
        Uses get and delete commands to ensure that only the owner can release the lock.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
            return
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to release Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)

            current_value = await self.redis_cache.async_get_cache(lock_key)
            if current_value is not None:
                if isinstance(current_value, bytes):
                    current_value = current_value.decode("utf-8")
                if current_value == self.pod_id:
                    result = await self.redis_cache.async_delete_cache(lock_key)
                    if result == 1:
                        verbose_proxy_logger.info(
                            "Pod %s successfully released Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_released_lock_event(
                            cronjob_id=cronjob_id,
                            pod_id=self.pod_id,
                        )
                    else:
                        verbose_proxy_logger.debug(
                            "Pod %s failed to release Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                else:
                    verbose_proxy_logger.debug(
                        "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                        self.pod_id,
                        cronjob_id,
                        current_value,
                    )
            else:
                verbose_proxy_logger.debug(
                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                    self.pod_id,
                    cronjob_id,
                )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error releasing Redis lock for {cronjob_id}: {e}"
            )

    @staticmethod
    def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_acquired_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 1,
                },
            )
        )

    @staticmethod
    def _emit_released_lock_event(cronjob_id: str, pod_id: str):
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_released_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 0,
                },
            )
        )
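A hedged sketch of guarding a cron-style job with `PodLockManager`, with the class above in scope; the `RedisCache()` construction (reading connection details from the environment) and the job name are assumptions for illustration.

```
import asyncio

from litellm.caching.redis_cache import RedisCache


async def run_commit_job():
    # assumes Redis connection settings (host/port/password) are available via env vars
    lock_manager = PodLockManager(redis_cache=RedisCache())
    if await lock_manager.acquire_lock(cronjob_id="commit_spend_updates"):
        try:
            ...  # work that only one pod should perform at a time
        finally:
            await lock_manager.release_lock(cronjob_id="commit_spend_updates")


asyncio.run(run_commit_job())
```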
@@ -0,0 +1,405 @@
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database

This is to prevent deadlocks and improve reliability
"""

import asyncio
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

from litellm._logging import verbose_proxy_logger
from litellm.caching import RedisCache
from litellm.constants import (
    MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
    REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
    REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
    REDIS_UPDATE_BUFFER_KEY,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import (
    DailyTeamSpendTransaction,
    DailyUserSpendTransaction,
    DBSpendUpdateTransactions,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
    DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
from litellm.secret_managers.main import str_to_bool
from litellm.types.services import ServiceTypes

if TYPE_CHECKING:
    from litellm.proxy.utils import PrismaClient
else:
    PrismaClient = Any


class RedisUpdateBuffer:
    """
    Handles buffering database `UPDATE` transactions in Redis before committing them to the database

    This is to prevent deadlocks and improve reliability
    """

    def __init__(
        self,
        redis_cache: Optional[RedisCache] = None,
    ):
        self.redis_cache = redis_cache

    @staticmethod
    def _should_commit_spend_updates_to_redis() -> bool:
        """
        Checks if the pod should commit spend updates to Redis.

        This setting enables buffering database transactions in Redis
        to improve reliability and reduce database contention.
        """
        from litellm.proxy.proxy_server import general_settings

        _use_redis_transaction_buffer: Optional[
            Union[bool, str]
        ] = general_settings.get("use_redis_transaction_buffer", False)
        if isinstance(_use_redis_transaction_buffer, str):
            _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
        if _use_redis_transaction_buffer is None:
            return False
        return _use_redis_transaction_buffer

    async def _store_transactions_in_redis(
        self,
        transactions: Any,
        redis_key: str,
        service_type: ServiceTypes,
    ) -> None:
        """
        Helper method to store transactions in Redis and emit an event

        Args:
            transactions: The transactions to store
            redis_key: The Redis key to store under
            service_type: The service type for event emission
        """
        if transactions is None or len(transactions) == 0:
            return

        list_of_transactions = [safe_dumps(transactions)]
        if self.redis_cache is None:
            return
        current_redis_buffer_size = await self.redis_cache.async_rpush(
            key=redis_key,
            values=list_of_transactions,
        )
        await self._emit_new_item_added_to_redis_buffer_event(
            queue_size=current_redis_buffer_size,
            service=service_type,
        )

    async def store_in_memory_spend_updates_in_redis(
        self,
        spend_update_queue: SpendUpdateQueue,
        daily_spend_update_queue: DailySpendUpdateQueue,
        daily_team_spend_update_queue: DailySpendUpdateQueue,
        daily_tag_spend_update_queue: DailySpendUpdateQueue,
    ):
        """
        Stores the in-memory spend updates to Redis

        Stores the following in memory data structures in Redis:
            - SpendUpdateQueue - Key, User, Team, TeamMember, Org, EndUser Spend updates
            - DailySpendUpdateQueue - Daily Spend updates Aggregate view

        For SpendUpdateQueue:
            Each transaction is a dict stored as following:
            - key is the entity id
            - value is the spend amount

            ```
            Redis List:
            key_list_transactions:
                [
                    "0929880201": 1.2,
                    "0929880202": 0.01,
                    "0929880203": 0.001,
                ]
            ```

        For DailySpendUpdateQueue:
            Each transaction is a Dict[str, DailyUserSpendTransaction] stored as following:
            - key is the daily_transaction_key
            - value is the DailyUserSpendTransaction

            ```
            Redis List:
            daily_spend_update_transactions:
                [
                    {
                        "user_keyhash_1_model_1": {
                            "spend": 1.2,
                            "prompt_tokens": 1000,
                            "completion_tokens": 1000,
                            "api_requests": 1000,
                            "successful_requests": 1000,
                        },
                    }
                ]
            ```
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug(
                "redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
            )
            return

        # Get all transactions
        db_spend_update_transactions = (
            await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
        )
        daily_spend_update_transactions = (
            await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_team_spend_update_transactions = (
            await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_tag_spend_update_transactions = (
            await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )

        verbose_proxy_logger.debug(
            "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
        )
        verbose_proxy_logger.debug(
            "ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
        )

        # only store in redis if there are any updates to commit
        if (
            self._number_of_transactions_to_store_in_redis(db_spend_update_transactions)
            == 0
        ):
            return

        # Store all transaction types using the helper method
        await self._store_transactions_in_redis(
            transactions=db_spend_update_transactions,
            redis_key=REDIS_UPDATE_BUFFER_KEY,
            service_type=ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
        )

        await self._store_transactions_in_redis(
            transactions=daily_spend_update_transactions,
            redis_key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
            service_type=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
        )

        await self._store_transactions_in_redis(
            transactions=daily_team_spend_update_transactions,
            redis_key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
            service_type=ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
        )

        await self._store_transactions_in_redis(
            transactions=daily_tag_spend_update_transactions,
            redis_key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
            service_type=ServiceTypes.REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE,
        )

    @staticmethod
    def _number_of_transactions_to_store_in_redis(
        db_spend_update_transactions: DBSpendUpdateTransactions,
    ) -> int:
        """
        Gets the number of transactions to store in Redis
        """
        num_transactions = 0
        for v in db_spend_update_transactions.values():
            if isinstance(v, dict):
                num_transactions += len(v)
        return num_transactions

    @staticmethod
    def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """
        Removes the specified prefix from the keys of a dictionary.
        """
        return {key.replace(prefix, "", 1): value for key, value in data.items()}

    async def get_all_update_transactions_from_redis_buffer(
        self,
    ) -> Optional[DBSpendUpdateTransactions]:
        """
        Gets all the update transactions from Redis

        On Redis we store a list of transactions as a JSON string

        eg.
            [
                DBSpendUpdateTransactions(
                    user_list_transactions={
                        "user_id_1": 1.2,
                        "user_id_2": 0.01,
                    },
                    end_user_list_transactions={},
                    key_list_transactions={
                        "0929880201": 1.2,
                        "0929880202": 0.01,
                    },
                    team_list_transactions={},
                    team_member_list_transactions={},
                    org_list_transactions={},
                ),
                DBSpendUpdateTransactions(
                    user_list_transactions={
                        "user_id_3": 1.2,
                        "user_id_4": 0.01,
                    },
                    end_user_list_transactions={},
                    key_list_transactions={
                        "key_id_1": 1.2,
                        "key_id_2": 0.01,
                    },
                    team_list_transactions={},
                    team_member_list_transactions={},
                    org_list_transactions={},
                ),
            ]
        """
        if self.redis_cache is None:
            return None
        list_of_transactions = await self.redis_cache.async_lpop(
            key=REDIS_UPDATE_BUFFER_KEY,
            count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
        )
        if list_of_transactions is None:
            return None

        # Parse the list of transactions from JSON strings
        parsed_transactions = self._parse_list_of_transactions(list_of_transactions)

        # If there are no transactions, return None
        if len(parsed_transactions) == 0:
            return None

        # Combine all transactions into a single transaction
        combined_transaction = self._combine_list_of_transactions(parsed_transactions)

        return combined_transaction

    async def get_all_daily_spend_update_transactions_from_redis_buffer(
        self,
    ) -> Optional[Dict[str, DailyUserSpendTransaction]]:
        """
        Gets all the daily spend update transactions from Redis
        """
        if self.redis_cache is None:
            return None
        list_of_transactions = await self.redis_cache.async_lpop(
            key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
            count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
        )
        if list_of_transactions is None:
            return None
        list_of_daily_spend_update_transactions = [
            json.loads(transaction) for transaction in list_of_transactions
        ]
        return cast(
            Dict[str, DailyUserSpendTransaction],
            DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                list_of_daily_spend_update_transactions
            ),
        )

    async def get_all_daily_team_spend_update_transactions_from_redis_buffer(
        self,
    ) -> Optional[Dict[str, DailyTeamSpendTransaction]]:
        """
        Gets all the daily team spend update transactions from Redis
        """
        if self.redis_cache is None:
            return None
        list_of_transactions = await self.redis_cache.async_lpop(
            key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
            count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
        )
        if list_of_transactions is None:
            return None
        list_of_daily_spend_update_transactions = [
            json.loads(transaction) for transaction in list_of_transactions
        ]
        return cast(
            Dict[str, DailyTeamSpendTransaction],
            DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                list_of_daily_spend_update_transactions
            ),
        )

    @staticmethod
    def _parse_list_of_transactions(
        list_of_transactions: Union[Any, List[Any]],
    ) -> List[DBSpendUpdateTransactions]:
        """
        Parses the list of transactions from Redis
        """
        if isinstance(list_of_transactions, list):
            return [json.loads(transaction) for transaction in list_of_transactions]
        else:
            return [json.loads(list_of_transactions)]

    @staticmethod
    def _combine_list_of_transactions(
        list_of_transactions: List[DBSpendUpdateTransactions],
    ) -> DBSpendUpdateTransactions:
        """
        Combines the list of transactions into a single DBSpendUpdateTransactions object
        """
        # Initialize a new combined transaction object with empty dictionaries
        combined_transaction = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
        )

        # Define the transaction fields to process
        transaction_fields = [
            "user_list_transactions",
            "end_user_list_transactions",
            "key_list_transactions",
            "team_list_transactions",
            "team_member_list_transactions",
            "org_list_transactions",
        ]

        # Loop through each transaction and combine the values
        for transaction in list_of_transactions:
            # Process each field type
            for field in transaction_fields:
                if transaction.get(field):
                    for entity_id, amount in transaction[field].items():  # type: ignore
                        combined_transaction[field][entity_id] = (  # type: ignore
                            combined_transaction[field].get(entity_id, 0) + amount  # type: ignore
                        )

        return combined_transaction

    async def _emit_new_item_added_to_redis_buffer_event(
        self,
        service: ServiceTypes,
        queue_size: int,
    ):
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=service,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": service,
                    "gauge_value": queue_size,
                },
            )
        )
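A hedged illustration of the merge behavior of `RedisUpdateBuffer._combine_list_of_transactions`, with the class above in scope; the entity ids and amounts are made up.

```
from litellm.proxy._types import DBSpendUpdateTransactions

t1 = DBSpendUpdateTransactions(
    user_list_transactions={"user_a": 1.0},
    end_user_list_transactions={},
    key_list_transactions={"key_1": 0.25},
    team_list_transactions={},
    team_member_list_transactions={},
    org_list_transactions={},
)
t2 = DBSpendUpdateTransactions(
    user_list_transactions={"user_a": 0.5, "user_b": 2.0},
    end_user_list_transactions={},
    key_list_transactions={},
    team_list_transactions={},
    team_member_list_transactions={},
    org_list_transactions={},
)

combined = RedisUpdateBuffer._combine_list_of_transactions([t1, t2])
# user_a's spend is summed across both buffered transactions
assert combined["user_list_transactions"] == {"user_a": 1.5, "user_b": 2.0}
assert combined["key_list_transactions"] == {"key_1": 0.25}
```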
@@ -0,0 +1,225 @@
import asyncio
from typing import Dict, List, Optional

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
    DBSpendUpdateTransactions,
    Litellm_EntityType,
    SpendUpdateQueueItem,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
    BaseUpdateQueue,
    service_logger_obj,
)
from litellm.types.services import ServiceTypes


class SpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for spend updates that should be committed to the database
    """

    def __init__(self):
        super().__init__()
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """Flush all updates from the queue and return all updates aggregated by entity type."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    async def add_update(self, update: SpendUpdateQueueItem):
        """Enqueue an update to the spend update queue"""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

        # if the queue is full, aggregate the updates
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """Concatenate all updates in the queue to reduce the size of the in-memory queue"""
        updates: List[
            SpendUpdateQueueItem
        ] = await self.flush_all_updates_from_in_memory_queue()
        aggregated_updates = self._get_aggregated_spend_update_queue_item(updates)
        for update in aggregated_updates:
            await self.update_queue.put(update)
        return

    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        This is used to reduce the size of the in-memory queue by aggregating updates by entity type + id.

        Aggregate updates by entity type + id, e.g.

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 100
            },
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 200
            }
        ]
        ```

        becomes

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 300
            }
        ]
        ```
        """
        verbose_proxy_logger.debug(
            "Aggregating spend updates, current queue size: %s",
            self.update_queue.qsize(),
        )
        aggregated_spend_updates: List[SpendUpdateQueueItem] = []

        _in_memory_map: Dict[str, SpendUpdateQueueItem] = {}
        """
        Used for combining several updates into a single update
        Key=entity_type:entity_id
        Value=SpendUpdateQueueItem
        """
        for update in updates:
            _key = f"{update.get('entity_type')}:{update.get('entity_id')}"
            if _key not in _in_memory_map:
                _in_memory_map[_key] = update
            else:
                current_cost = _in_memory_map[_key].get("response_cost", 0) or 0
                update_cost = update.get("response_cost", 0) or 0
                _in_memory_map[_key]["response_cost"] = current_cost + update_cost

        for _key, update in _in_memory_map.items():
            aggregated_spend_updates.append(update)

        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """Aggregate updates by entity type."""
        # Initialize all transaction lists as empty dicts
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
        )

        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
        }

        for update in updates:
            entity_type = update.get("entity_type")
            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0

            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue

            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types

            # Type-safe access using if/elif statements
            if dict_key == "user_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "user_list_transactions"
                ]
            elif dict_key == "end_user_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "end_user_list_transactions"
                ]
            elif dict_key == "key_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "key_list_transactions"
                ]
            elif dict_key == "team_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "team_list_transactions"
                ]
            elif dict_key == "team_member_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "team_member_list_transactions"
                ]
            elif dict_key == "org_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "org_list_transactions"
                ]
            else:
                continue

            if transactions_dict is None:
                transactions_dict = {}

            # type ignore: dict_key is guaranteed to be one of ("user_list_transactions", "end_user_list_transactions", "key_list_transactions", "team_list_transactions", "team_member_list_transactions", "org_list_transactions")
            db_spend_update_transactions[dict_key] = transactions_dict  # type: ignore

            if entity_id not in transactions_dict:
                transactions_dict[entity_id] = 0

            transactions_dict[entity_id] += response_cost or 0

        return db_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
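A hedged end-to-end sketch of `SpendUpdateQueue`; the entity id and costs are invented, plain dicts stand in for `SpendUpdateQueueItem` entries, and the import paths mirror the ones used elsewhere in this commit.

```
import asyncio

from litellm.proxy._types import Litellm_EntityType
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue


async def main():
    queue = SpendUpdateQueue()
    # two updates for the same key are summed during aggregation
    await queue.add_update(
        {"entity_type": Litellm_EntityType.KEY, "entity_id": "sk-123", "response_cost": 0.1}
    )
    await queue.add_update(
        {"entity_type": Litellm_EntityType.KEY, "entity_id": "sk-123", "response_cost": 0.2}
    )
    txns = await queue.flush_and_get_aggregated_db_spend_update_transactions()
    print(txns["key_list_transactions"])  # expected: spend of ~0.3 for "sk-123"


asyncio.run(main())
```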
@@ -0,0 +1,71 @@
"""
Deprecated. Only PostgreSQL is supported.
"""

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import DynamoDBArgs
from litellm.proxy.db.base_client import CustomDB


class DynamoDBWrapper(CustomDB):
    from aiodynamo.credentials import Credentials, StaticCredentials

    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        from aiodynamo.models import PayPerRequest, Throughput

        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            if (
                database_arguments.read_capacity_units is not None
                and isinstance(database_arguments.read_capacity_units, int)
                and database_arguments.write_capacity_units is not None
                and isinstance(database_arguments.write_capacity_units, int)
            ):
                self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name

    def set_env_vars_based_on_arn(self):
        if self.database_arguments.aws_role_name is None:
            return
        verbose_proxy_logger.debug(
            f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
        )
        import os

        import boto3

        sts_client = boto3.client("sts")

        # call 1
        sts_client.assume_role_with_web_identity(
            RoleArn=self.database_arguments.aws_role_name,
            RoleSessionName=self.database_arguments.aws_session_name,
            WebIdentityToken=self.database_arguments.aws_web_identity_token,
        )

        # call 2
        assumed_role = sts_client.assume_role(
            RoleArn=self.database_arguments.assume_role_aws_role_name,
            RoleSessionName=self.database_arguments.assume_role_aws_session_name,
        )

        aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"]
        aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"]
        aws_session_token = assumed_role["Credentials"]["SessionToken"]

        verbose_proxy_logger.debug(
            f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}"
        )
        # set these in the env so aiodynamo can use them
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
        os.environ["AWS_SESSION_TOKEN"] = aws_session_token
@@ -0,0 +1,61 @@
from typing import Union

from litellm.proxy._types import (
    DB_CONNECTION_ERROR_TYPES,
    ProxyErrorTypes,
    ProxyException,
)
from litellm.secret_managers.main import str_to_bool


class PrismaDBExceptionHandler:
    """
    Class to handle DB Exceptions or Connection Errors
    """

    @staticmethod
    def should_allow_request_on_db_unavailable() -> bool:
        """
        Returns True if the request should be allowed to proceed despite the DB connection error
        """
        from litellm.proxy.proxy_server import general_settings

        _allow_requests_on_db_unavailable: Union[bool, str] = general_settings.get(
            "allow_requests_on_db_unavailable", False
        )
        if isinstance(_allow_requests_on_db_unavailable, bool):
            return _allow_requests_on_db_unavailable
        if str_to_bool(_allow_requests_on_db_unavailable) is True:
            return True
        return False

    @staticmethod
    def is_database_connection_error(e: Exception) -> bool:
        """
        Returns True if the exception is from a database outage / connection error
        """
        import prisma

        if isinstance(e, DB_CONNECTION_ERROR_TYPES):
            return True
        if isinstance(e, prisma.errors.PrismaError):
            return True
        if isinstance(e, ProxyException) and e.type == ProxyErrorTypes.no_db_connection:
            return True
        return False

    @staticmethod
    def handle_db_exception(e: Exception):
        """
        Primary handler for the `allow_requests_on_db_unavailable` flag. Decides whether to raise a DB exception or not, based on the flag.

        - If the exception is a DB connection error and `allow_requests_on_db_unavailable` is True,
          do not raise an exception; return None.
        - Else, raise the exception.
        """
        if (
            PrismaDBExceptionHandler.is_database_connection_error(e)
            and PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
        ):
            return None
        raise e
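A hedged sketch of wrapping a DB read with `PrismaDBExceptionHandler`, with the class above in scope; `fetch_user_row` is a hypothetical helper used only for illustration.

```
async def get_user_row_or_none(prisma_client, user_id: str):
    try:
        return await prisma_client.fetch_user_row(user_id)  # hypothetical DB call
    except Exception as e:
        # Re-raises unless this is a connection-type failure and
        # `allow_requests_on_db_unavailable` is enabled; in that case returns None.
        PrismaDBExceptionHandler.handle_db_exception(e)
        return None
```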
@@ -0,0 +1,142 @@
"""
Handles logging DB success/failure to ServiceLogger()

ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
"""

import asyncio
from datetime import datetime
from functools import wraps
from typing import Callable, Dict, Tuple

from litellm._service_logger import ServiceTypes
from litellm.litellm_core_utils.core_helpers import (
    _get_parent_otel_span_from_kwargs,
    get_litellm_metadata_from_kwargs,
)


def log_db_metrics(func):
    """
    Decorator to log the duration of a DB related function to ServiceLogger()

    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog

    When logging a failure, it checks if the exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure

    Args:
        func: The function to be decorated

    Returns:
        Result from the decorated function

    Raises:
        Exception: If the decorated function raises an exception
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time: datetime = datetime.now()

        try:
            result = await func(*args, **kwargs)
            end_time: datetime = datetime.now()
            from litellm.proxy.proxy_server import proxy_logging_obj

            if "PROXY" not in func.__name__:
                asyncio.create_task(
                    proxy_logging_obj.service_logging_obj.async_service_success_hook(
                        service=ServiceTypes.DB,
                        call_type=func.__name__,
                        parent_otel_span=kwargs.get("parent_otel_span", None),
                        duration=(end_time - start_time).total_seconds(),
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata={
                            "function_name": func.__name__,
                            "function_kwargs": kwargs,
                            "function_args": args,
                        },
                    )
                )
            elif (
                # in litellm custom callbacks kwargs is passed as arg[0]
                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
                args is not None
                and len(args) > 1
                and isinstance(args[1], dict)
            ):
                passed_kwargs = args[1]
                parent_otel_span = _get_parent_otel_span_from_kwargs(
                    kwargs=passed_kwargs
                )
                if parent_otel_span is not None:
                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)

                    asyncio.create_task(
                        proxy_logging_obj.service_logging_obj.async_service_success_hook(
                            service=ServiceTypes.BATCH_WRITE_TO_DB,
                            call_type=func.__name__,
                            parent_otel_span=parent_otel_span,
                            duration=0.0,
                            start_time=start_time,
                            end_time=end_time,
                            event_metadata=metadata,
                        )
                    )
            # end of logging to otel
            return result
        except Exception as e:
            end_time: datetime = datetime.now()
            await _handle_logging_db_exception(
                e=e,
                func=func,
                kwargs=kwargs,
                args=args,
                start_time=start_time,
                end_time=end_time,
            )
            raise e

    return wrapper


def _is_exception_related_to_db(e: Exception) -> bool:
    """
    Returns True if the exception is related to the DB
    """

    import httpx
    from prisma.errors import PrismaError

    return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException))


async def _handle_logging_db_exception(
    e: Exception,
    func: Callable,
    kwargs: Dict,
    args: Tuple,
    start_time: datetime,
    end_time: datetime,
) -> None:
    from litellm.proxy.proxy_server import proxy_logging_obj

    # don't log this as a DB Service Failure, if the DB did not raise an exception
    if _is_exception_related_to_db(e) is not True:
        return

    await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
        error=e,
        service=ServiceTypes.DB,
        call_type=func.__name__,
        parent_otel_span=kwargs.get("parent_otel_span"),
        duration=(end_time - start_time).total_seconds(),
        start_time=start_time,
        end_time=end_time,
        event_metadata={
            "function_name": func.__name__,
            "function_kwargs": kwargs,
            "function_args": args,
        },
    )
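A hedged sketch of applying the `log_db_metrics` decorator, with the decorator above in scope; the function body and the `fetch_key_row` helper are placeholders, not code from this commit.

```
@log_db_metrics
async def get_spend_for_key(prisma_client, token: str, parent_otel_span=None) -> float:
    row = await prisma_client.fetch_key_row(token)  # hypothetical DB call
    return row["spend"] if row else 0.0

# On success the wrapper emits a ServiceTypes.DB success event with the call duration;
# on PrismaError / httpx connection errors it emits a DB failure event and re-raises.
```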
@@ -0,0 +1,209 @@
"""
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
"""

import asyncio
import os
import random
import subprocess
import time
import urllib
import urllib.parse
from datetime import datetime, timedelta
from typing import Any, Optional, Union

from litellm._logging import verbose_proxy_logger
from litellm.secret_managers.main import str_to_bool


class PrismaWrapper:
    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        if token_url is None:
            return True
        # Decode the token URL to handle URL-encoded characters
        decoded_url = urllib.parse.unquote(token_url)

        # Parse the token URL
        parsed_url = urllib.parse.urlparse(decoded_url)

        # Parse the query parameters from the query string
        query_params = urllib.parse.parse_qs(parsed_url.query)

        # Get expiration time from the query parameters
        expires = query_params.get("X-Amz-Expires", [None])[0]
        if expires is None:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.")

        expires_int = int(expires)

        # Get the token's creation time from the X-Amz-Date parameter
        token_time_str = query_params.get("X-Amz-Date", [""])[0]
        if not token_time_str:
            raise ValueError("X-Amz-Date parameter is missing or invalid.")

        # Ensure the token time string is parsed correctly
        try:
            token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
        except ValueError as e:
            raise ValueError(f"Invalid X-Amz-Date format: {e}")

        # Calculate the expiration time
        expiration_time = token_time + timedelta(seconds=expires_int)

        # Current time in UTC
        current_time = datetime.utcnow()

        # Check if the token is expired
        return current_time > expiration_time
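
    # Worked example (illustrative values only): if the decoded DATABASE_URL
    # carries X-Amz-Date=20240101T000000Z and X-Amz-Expires=900 in its query
    # string, is_token_expired() treats the token as expired once the current
    # UTC time passes 2024-01-01 00:15:00, i.e. X-Amz-Date + X-Amz-Expires seconds.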

    def get_rds_iam_token(self) -> Optional[str]:
        if self.iam_token_db_auth:
            from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

            db_host = os.getenv("DATABASE_HOST")
            db_port = os.getenv("DATABASE_PORT")
            db_user = os.getenv("DATABASE_USER")
            db_name = os.getenv("DATABASE_NAME")
            db_schema = os.getenv("DATABASE_SCHEMA")

            token = generate_iam_auth_token(
                db_host=db_host, db_port=db_port, db_user=db_user
            )

            # print(f"token: {token}")
            _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
            if db_schema:
                _db_url += f"?schema={db_schema}"

            os.environ["DATABASE_URL"] = _db_url
            return _db_url
        return None
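
    # Usage note (grounded in the code above): with iam_token_db_auth enabled,
    # DATABASE_HOST, DATABASE_PORT, DATABASE_USER, DATABASE_NAME (and optionally
    # DATABASE_SCHEMA) must be set. The refreshed DSN is written back to the
    # DATABASE_URL env var, which a re-created Prisma client is expected to
    # read when it connects.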

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        from prisma import Prisma  # type: ignore

        if http_client is not None:
            self._original_prisma = Prisma(http=http_client)
        else:
            self._original_prisma = Prisma()

        await self._original_prisma.connect()

    def __getattr__(self, name: str):
        original_attr = getattr(self._original_prisma, name)
        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")
            if self.is_token_expired(db_url):
                db_url = self.get_rds_iam_token()
                loop = asyncio.get_event_loop()

                if db_url:
                    if loop.is_running():
                        asyncio.run_coroutine_threadsafe(
                            self.recreate_prisma_client(db_url), loop
                        )
                    else:
                        asyncio.run(self.recreate_prisma_client(db_url))
                else:
                    raise ValueError("Failed to get RDS IAM token")

        return original_attr
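
# Usage sketch (illustrative; the client and table names below are hypothetical):
#
#   prisma_client = Prisma()
#   db = PrismaWrapper(original_prisma=prisma_client, iam_token_db_auth=True)
#   # attribute access is proxied to the wrapped client; with IAM auth enabled,
#   # an expired token triggers a refresh and client re-creation first
#   users = await db.litellm_usertable.find_many()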


class PrismaManager:
    @staticmethod
    def _get_prisma_dir() -> str:
        """Get the path to the directory containing the Prisma schema"""
        abspath = os.path.abspath(__file__)
        dname = os.path.dirname(os.path.dirname(abspath))
        return dname

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push

        Returns:
            bool: True if setup was successful, False otherwise
        """

        use_migrate = str_to_bool(os.getenv("USE_PRISMA_MIGRATE")) or use_migrate
        for attempt in range(4):
            original_dir = os.getcwd()
            prisma_dir = PrismaManager._get_prisma_dir()
            schema_path = prisma_dir + "/schema.prisma"
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    try:
                        from litellm_proxy_extras.utils import ProxyExtrasDBManager
                    except ImportError as e:
                        verbose_proxy_logger.error(
                            f"\033[1;31mLiteLLM: Failed to import proxy extras. Got {e}\033[0m"
                        )
                        return False

                    prisma_dir = PrismaManager._get_prisma_dir()
                    schema_path = prisma_dir + "/schema.prisma"

                    return ProxyExtrasDBManager.setup_database(
                        schema_path=schema_path, use_migrate=use_migrate
                    )
                else:
                    # Use prisma db push with increased timeout
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
                time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                attempts_left = 3 - attempt
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
                time.sleep(random.randrange(5, 15))
            finally:
                os.chdir(original_dir)
        return False
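
    # Usage sketch (illustrative): USE_PRISMA_MIGRATE=true (or use_migrate=True)
    # routes through litellm_proxy_extras' migrate-based setup; otherwise the
    # method shells out to `prisma db push`, with up to 4 attempts on failure.
    #
    #   ok = PrismaManager.setup_database(use_migrate=False)
    #   if not ok:
    #       raise RuntimeError("could not apply the Prisma schema")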


def should_update_prisma_schema(
    disable_updates: Optional[Union[bool, str]] = None
) -> bool:
    """
    Determines if Prisma Schema updates should be applied during startup.

    Args:
        disable_updates: Controls whether schema updates are disabled.
            Accepts boolean or string ('true'/'false'). Defaults to checking DISABLE_SCHEMA_UPDATE env var.

    Returns:
        bool: True if schema updates should be applied, False if updates are disabled.

    Examples:
        >>> should_update_prisma_schema()  # Checks DISABLE_SCHEMA_UPDATE env var
        >>> should_update_prisma_schema(True)  # Explicitly disable updates
        >>> should_update_prisma_schema("false")  # Enable updates using string
    """
    if disable_updates is None:
        disable_updates = os.getenv("DISABLE_SCHEMA_UPDATE", "false")

    if isinstance(disable_updates, str):
        disable_updates = str_to_bool(disable_updates)

    return not bool(disable_updates)
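
# Startup sketch (illustrative; this call site is assumed, not shown in the diff):
# schema updates are typically gated on the env flag before setup runs.
#
#   if should_update_prisma_schema():
#       PrismaManager.setup_database(use_migrate=True)
#   else:
#       verbose_proxy_logger.info("DISABLE_SCHEMA_UPDATE set - skipping schema update")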