structure saas with tools
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
BUDGET MANAGEMENT
|
||||
|
||||
All /budget management endpoints
|
||||
|
||||
/budget/new
|
||||
/budget/info
|
||||
/budget/update
|
||||
/budget/delete
|
||||
/budget/settings
|
||||
/budget/list
|
||||
"""
|
||||
|
||||
#### BUDGET TABLE MANAGEMENT ####
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.utils import jsonify_object
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/budget/new",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new budget row. The resulting budget can be attached to teams,
    orgs, end-users and keys.

    Parameters:
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - budget_id: Optional[str] - The id of the budget. If not provided, a new id will be generated.
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}

    Returns the record returned by the budget table `create` call.

    Raises:
    - HTTPException(500) when no database client is available.
    """
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Only explicitly provided fields are written; nested dictionaries are
    # json-dumped before they reach the database layer.
    row_fields = jsonify_object(budget_obj.model_dump(exclude_none=True))

    # Audit columns: the calling user, or the proxy admin name when the key
    # carries no user id.
    actor = user_api_key_dict.user_id or litellm_proxy_admin_name

    return await prisma_client.db.litellm_budgettable.create(
        data={
            **row_fields,  # type: ignore
            "created_by": actor,
            "updated_by": actor,
        }  # type: ignore
    )
|
||||
|
||||
|
||||
@router.post(
    "/budget/update",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_budget(
    budget_obj: BudgetNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing budget object.

    Parameters:
    - budget_id: str - The id of the budget to update. Required; the request is rejected with a 400 when it is missing.
    - budget_duration: Optional[str] - Budget reset period ("30d", "1h", etc.)
    - max_budget: Optional[float] - The max budget for the budget.
    - soft_budget: Optional[float] - The soft budget for the budget.
    - max_parallel_requests: Optional[int] - The max number of parallel requests for the budget.
    - tpm_limit: Optional[int] - The tokens per minute limit for the budget.
    - rpm_limit: Optional[int] - The requests per minute limit for the budget.
    - model_max_budget: Optional[dict] - Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d", "tpm_limit": 100000, "rpm_limit": 100000}}

    Only fields that are not None are written, so omitted fields keep their
    stored value.

    Raises:
    - HTTPException(500) when no database client is available.
    - HTTPException(400) when `budget_id` is not provided.
    """
    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    if budget_obj.budget_id is None:
        raise HTTPException(status_code=400, detail={"error": "budget_id is required"})

    # Serialize nested dictionaries (e.g. model_max_budget) exactly the way
    # /budget/new does, so create and update write the same representation.
    update_fields = jsonify_object(budget_obj.model_dump(exclude_none=True))

    response = await prisma_client.db.litellm_budgettable.update(
        where={"budget_id": budget_obj.budget_id},
        data={
            **update_fields,  # type: ignore
            "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
        },  # type: ignore
    )

    return response
|
||||
|
||||
|
||||
@router.post(
    "/budget/info",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_budget(data: BudgetRequest):
    """
    Look up the stored rows for a list of budget ids.

    Parameters:
    - budgets: List[str] - The list of budget ids to get information for

    Raises:
    - HTTPException(500) when no database client is available.
    - HTTPException(400) when the list of budget ids is empty.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    requested_ids = data.budgets
    if len(requested_ids) == 0:
        # Nothing to query for; tell the caller what was received.
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Specify list of budget id's to query. Passed in={data.budgets}"
            },
        )

    return await prisma_client.db.litellm_budgettable.find_many(
        where={"budget_id": {"in": requested_ids}},
    )
|
||||
|
||||
|
||||
@router.get(
    "/budget/settings",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def budget_settings(
    budget_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Describe the configurable fields of one budget: name, display type,
    description, current stored value and default value.

    Used on Admin UI.

    Query Parameters:
    - budget_id: str - The budget id to get information for

    Raises:
    - HTTPException(400) when no database client is available.
    - HTTPException(400) when the caller is not a proxy admin.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "{}, your role={}".format(
                    CommonProxyErrors.not_allowed_access.value,
                    user_api_key_dict.user_role,
                )
            },
        )

    # Current values come from the stored row; an unknown budget id simply
    # yields no current values.
    stored_row = await prisma_client.db.litellm_budgettable.find_first(
        where={"budget_id": budget_id}
    )
    current_values = (
        {} if stored_row is None else stored_row.model_dump(exclude_none=True)
    )

    # Fields exposed to the UI, mapped to the type label shown for each.
    exposed_field_types = {
        "max_parallel_requests": "Integer",
        "tpm_limit": "Integer",
        "rpm_limit": "Integer",
        "budget_duration": "String",
        "max_budget": "Float",
        "soft_budget": "Float",
    }

    return [
        ConfigList(
            field_name=name,
            field_type=exposed_field_types[name],
            field_description=info.description or "",
            field_value=current_values.get(name, None),
            stored_in_db=True,
            field_default_value=info.default,
        )
        for name, info in BudgetNewRequest.model_fields.items()
        if name in exposed_field_types
    ]
|
||||
|
||||
|
||||
@router.get(
    "/budget/list",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def list_budget(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return every budget row stored in the proxy db. Used on Admin UI.

    Raises:
    - HTTPException(400) when no database client is available.
    - HTTPException(400) when the caller is not a proxy admin.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    caller_role = user_api_key_dict.user_role
    if caller_role != LitellmUserRoles.PROXY_ADMIN:
        denial_message = "{}, your role={}".format(
            CommonProxyErrors.not_allowed_access.value,
            caller_role,
        )
        raise HTTPException(status_code=400, detail={"error": denial_message})

    # No filter: the whole budget table is returned.
    return await prisma_client.db.litellm_budgettable.find_many()
|
||||
|
||||
|
||||
@router.post(
    "/budget/delete",
    tags=["budget management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_budget(
    data: BudgetDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a single budget row.

    Parameters:
    - id: str - The budget id to delete

    Returns whatever the budget table `delete` call returns.

    Raises:
    - HTTPException(500) when no database client is available.
    - HTTPException(400) when the caller is not a proxy admin.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    caller_role = user_api_key_dict.user_role
    if caller_role != LitellmUserRoles.PROXY_ADMIN:
        denial_message = "{}, your role={}".format(
            CommonProxyErrors.not_allowed_access.value,
            caller_role,
        )
        raise HTTPException(status_code=400, detail={"error": denial_message})

    return await prisma_client.db.litellm_budgettable.delete(
        where={"budget_id": data.id}
    )
|
||||
@@ -0,0 +1,274 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
BreakdownMetrics,
|
||||
DailySpendData,
|
||||
DailySpendMetadata,
|
||||
KeyMetadata,
|
||||
KeyMetricWithMetadata,
|
||||
MetricWithMetadata,
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
SpendMetrics,
|
||||
)
|
||||
|
||||
|
||||
def update_metrics(existing_metrics: SpendMetrics, record: Any) -> SpendMetrics:
    """
    Add one record's counters onto `existing_metrics` in place.

    `total_tokens` is not read from the record; it grows by the record's
    prompt tokens plus completion tokens.

    Returns the same `existing_metrics` object that was passed in.
    """

    def _accumulate(counter_names):
        # Each listed counter is summed field-for-field from the record.
        for counter in counter_names:
            setattr(
                existing_metrics,
                counter,
                getattr(existing_metrics, counter) + getattr(record, counter),
            )

    _accumulate(("spend", "prompt_tokens", "completion_tokens"))
    existing_metrics.total_tokens += record.prompt_tokens + record.completion_tokens
    _accumulate(
        (
            "cache_read_input_tokens",
            "cache_creation_input_tokens",
            "api_requests",
            "successful_requests",
            "failed_requests",
        )
    )
    return existing_metrics
|
||||
|
||||
|
||||
def update_breakdown_metrics(
    breakdown: BreakdownMetrics,
    record: Any,
    model_metadata: Dict[str, Dict[str, Any]],
    provider_metadata: Dict[str, Dict[str, Any]],
    api_key_metadata: Dict[str, Dict[str, Any]],
    entity_id_field: Optional[str] = None,
    entity_metadata_field: Optional[Dict[str, dict]] = None,
) -> BreakdownMetrics:
    """
    Fold one record into the per-model, per-provider, per-api-key and
    (optionally) per-entity buckets of `breakdown`.

    A bucket is created the first time its key is seen, seeded with empty
    metrics and whatever metadata the matching lookup dict holds for that key.
    The record is then added to the bucket through `update_metrics`.

    Args:
        breakdown: Breakdown object to update in place.
        record: Row providing `model`, `custom_llm_provider`, `api_key` and
            the counters consumed by `update_metrics`.
        model_metadata: Metadata per model name.
        provider_metadata: Metadata per provider name.
        api_key_metadata: Per api key dict with `key_alias` / `team_id`.
        entity_id_field: Name of the record attribute identifying the entity;
            the entity bucket is skipped when this is not set or the record
            has no truthy value for it.
        entity_metadata_field: Metadata per entity value.

    Returns:
        The same `breakdown` object.
    """
    # --- model bucket ---
    model_name = record.model
    if model_name not in breakdown.models:
        breakdown.models[model_name] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=model_metadata.get(model_name, {}),
        )
    model_bucket = breakdown.models[model_name]
    model_bucket.metrics = update_metrics(model_bucket.metrics, record)

    # --- provider bucket (records without a provider land under "unknown") ---
    provider = record.custom_llm_provider or "unknown"
    if provider not in breakdown.providers:
        breakdown.providers[provider] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=provider_metadata.get(provider, {}),
        )
    provider_bucket = breakdown.providers[provider]
    provider_bucket.metrics = update_metrics(provider_bucket.metrics, record)

    # --- api key bucket ---
    key = record.api_key
    if key not in breakdown.api_keys:
        key_info = api_key_metadata.get(key, {})
        breakdown.api_keys[key] = KeyMetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=KeyMetadata(
                key_alias=key_info.get("key_alias", None),
                team_id=key_info.get("team_id", None),
            ),
        )
    key_bucket = breakdown.api_keys[key]
    key_bucket.metrics = update_metrics(key_bucket.metrics, record)

    # --- entity bucket, only when the caller names an entity field ---
    if not entity_id_field:
        return breakdown
    entity_value = getattr(record, entity_id_field, None)
    if not entity_value:
        return breakdown

    if entity_value not in breakdown.entities:
        if entity_metadata_field:
            entity_meta = entity_metadata_field.get(entity_value, {})
        else:
            entity_meta = {}
        breakdown.entities[entity_value] = MetricWithMetadata(
            metrics=SpendMetrics(),
            metadata=entity_meta,
        )
    entity_bucket = breakdown.entities[entity_value]
    entity_bucket.metrics = update_metrics(entity_bucket.metrics, record)

    return breakdown
|
||||
|
||||
|
||||
async def get_api_key_metadata(
    prisma_client: PrismaClient,
    api_keys: Set[str],
) -> Dict[str, Dict[str, Any]]:
    """
    Fetch key alias and team id for a set of API key tokens in one query.

    Args:
        prisma_client: Database client used to query the verification token table.
        api_keys: Tokens to look up.

    Returns:
        Mapping of token -> {"key_alias": ..., "team_id": ...} for every token
        found. Tokens with no matching row are absent from the result. An
        empty `api_keys` set returns an empty dict without touching the
        database.
    """
    if not api_keys:
        # Nothing to look up; skip the round trip entirely.
        return {}

    key_records = await prisma_client.db.litellm_verificationtoken.find_many(
        where={"token": {"in": list(api_keys)}}
    )
    return {
        k.token: {"key_alias": k.key_alias, "team_id": k.team_id} for k in key_records
    }
|
||||
|
||||
|
||||
def _build_daily_activity_where(
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    start_date: str,
    end_date: str,
    model: Optional[str],
    api_key: Optional[str],
    exclude_entity_ids: Optional[List[str]],
) -> Dict[str, Any]:
    """Build the `where` filter used for both the count and the page query."""
    where_conditions: Dict[str, Any] = {
        "date": {
            "gte": start_date,
            "lte": end_date,
        }
    }

    if model:
        where_conditions["model"] = model
    if api_key:
        where_conditions["api_key"] = api_key
    if entity_id is not None:
        if isinstance(entity_id, list):
            where_conditions[entity_id_field] = {"in": entity_id}
        else:
            where_conditions[entity_id_field] = entity_id
    if exclude_entity_ids:
        entity_filter = where_conditions.get(entity_id_field)
        if not isinstance(entity_filter, dict):
            # A bare scalar filter cannot carry a nested "not" clause (item
            # assignment on a str raises TypeError), so promote it to an
            # explicit filter object before adding the exclusion.
            entity_filter = {} if entity_filter is None else {"equals": entity_filter}
            where_conditions[entity_id_field] = entity_filter
        entity_filter["not"] = {"in": exclude_entity_ids}

    return where_conditions


async def get_daily_activity(
    prisma_client: Optional[PrismaClient],
    table_name: str,
    entity_id_field: str,
    entity_id: Optional[Union[str, List[str]]],
    entity_metadata_field: Optional[Dict[str, dict]],
    start_date: Optional[str],
    end_date: Optional[str],
    model: Optional[str],
    api_key: Optional[str],
    page: int,
    page_size: int,
    exclude_entity_ids: Optional[List[str]] = None,
) -> SpendAnalyticsPaginatedResponse:
    """
    Common function to get daily activity for any entity type.

    Reads one page of rows from `table_name`, groups them by date and returns
    per-date metrics and breakdowns plus page-level totals.

    Args:
        prisma_client: Database client; a 500 is raised when it is None.
        table_name: Attribute name of the table on `prisma_client.db`.
        entity_id_field: Column that identifies the entity in that table.
        entity_id: Single entity id or list of ids to restrict to; None for
            no restriction.
        entity_metadata_field: Optional metadata per entity id, attached to
            the entity breakdown.
        start_date / end_date: Inclusive date bounds; both are required.
        model / api_key: Optional exact-match filters.
        page / page_size: Pagination controls (page is 1-based, judging by the
            `(page - 1) * page_size` skip).
        exclude_entity_ids: Optional entity ids to exclude.

    Returns:
        SpendAnalyticsPaginatedResponse with results sorted by date, newest
        first. The totals in `metadata` are summed over the rows of the
        returned page only, while `total_pages` / `has_more` are derived from
        the count of all matching rows.

    Raises:
        HTTPException(500) when the database is not connected or the query /
        aggregation fails; HTTPException(400) when a date bound is missing.
    """
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if start_date is None or end_date is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "Please provide start_date and end_date"},
        )

    try:
        where_conditions = _build_daily_activity_where(
            entity_id_field=entity_id_field,
            entity_id=entity_id,
            start_date=start_date,
            end_date=end_date,
            model=model,
            api_key=api_key,
            exclude_entity_ids=exclude_entity_ids,
        )
        table = getattr(prisma_client.db, table_name)

        # Get total count for pagination
        total_count = await table.count(where=where_conditions)

        # Fetch paginated results
        daily_spend_data = await table.find_many(
            where=where_conditions,
            order=[
                {"date": "desc"},
            ],
            skip=(page - 1) * page_size,
            take=page_size,
        )

        # Collect the distinct API keys on this page so their alias / team
        # can be fetched in one bulk query.
        api_keys = {record.api_key for record in daily_spend_data if record.api_key}

        api_key_metadata: Dict[str, Dict[str, Any]] = {}
        model_metadata: Dict[str, Dict[str, Any]] = {}
        provider_metadata: Dict[str, Dict[str, Any]] = {}
        if api_keys:
            api_key_metadata = await get_api_key_metadata(prisma_client, api_keys)

        # Process results
        total_metrics = SpendMetrics()
        grouped_data: Dict[str, Dict[str, Any]] = {}

        for record in daily_spend_data:
            date_str = record.date
            if date_str not in grouped_data:
                grouped_data[date_str] = {
                    "metrics": SpendMetrics(),
                    "breakdown": BreakdownMetrics(),
                }
            day = grouped_data[date_str]

            day["metrics"] = update_metrics(day["metrics"], record)
            day["breakdown"] = update_breakdown_metrics(
                day["breakdown"],
                record,
                model_metadata,
                provider_metadata,
                api_key_metadata,
                entity_id_field=entity_id_field,
                entity_metadata_field=entity_metadata_field,
            )

            # Page-level totals use the same accumulation as per-day metrics.
            update_metrics(total_metrics, record)

        # Convert grouped data to response format
        results = [
            DailySpendData(
                date=datetime.strptime(date_str, "%Y-%m-%d").date(),
                metrics=data["metrics"],
                breakdown=data["breakdown"],
            )
            for date_str, data in grouped_data.items()
        ]

        # Sort results by date
        results.sort(key=lambda x: x.date, reverse=True)

        return SpendAnalyticsPaginatedResponse(
            results=results,
            metadata=DailySpendMetadata(
                total_spend=total_metrics.spend,
                total_prompt_tokens=total_metrics.prompt_tokens,
                total_completion_tokens=total_metrics.completion_tokens,
                total_tokens=total_metrics.total_tokens,
                total_api_requests=total_metrics.api_requests,
                total_successful_requests=total_metrics.successful_requests,
                total_failed_requests=total_metrics.failed_requests,
                total_cache_read_input_tokens=total_metrics.cache_read_input_tokens,
                total_cache_creation_input_tokens=total_metrics.cache_creation_input_tokens,
                page=page,
                total_pages=-(-total_count // page_size),  # Ceiling division
                has_more=(page * page_size) < total_count,
            ),
        )

    except Exception as e:
        verbose_proxy_logger.exception(f"Error fetching daily activity: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Failed to fetch analytics: {str(e)}"},
        )
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from litellm.proxy._types import (
|
||||
GenerateKeyRequest,
|
||||
LiteLLM_ManagementEndpoint_MetadataFields_Premium,
|
||||
LiteLLM_TeamTable,
|
||||
LitellmUserRoles,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.utils import _premium_user_check
|
||||
|
||||
|
||||
def _user_has_admin_view(user_api_key_dict: UserAPIKeyAuth) -> bool:
    """Return True when the caller's role is proxy admin or view-only proxy admin."""
    role = user_api_key_dict.user_role
    if role == LitellmUserRoles.PROXY_ADMIN:
        return True
    return role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
|
||||
|
||||
|
||||
def _is_user_team_admin(
    user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
) -> bool:
    """
    Return True when the caller appears in the team's member list with the
    "admin" role.

    Members without a user id never match, so a caller whose own user id is
    None is never treated as a team admin.
    """
    return any(
        member.user_id is not None
        and member.user_id == user_api_key_dict.user_id
        and member.role == "admin"
        for member in team_obj.members_with_roles
    )
|
||||
|
||||
|
||||
def _set_object_metadata_field(
    object_data: Union[LiteLLM_TeamTable, GenerateKeyRequest],
    field_name: str,
    value: Any,
) -> None:
    """
    Store `value` under `field_name` in the object's metadata dict.

    When `field_name` is one of the premium metadata fields, the premium-user
    check runs first; the metadata is only touched after it returns.

    Args:
        object_data: Team table row or key generation request to modify in place
        field_name: Name of the metadata field to set
        value: Value to set for the field
    """
    requires_premium = field_name in LiteLLM_ManagementEndpoint_MetadataFields_Premium
    if requires_premium:
        _premium_user_check()

    # Start from an empty dict when no metadata has been set yet.
    object_data.metadata = object_data.metadata or {}
    object_data.metadata[field_name] = value
|
||||
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
CUSTOMER MANAGEMENT
|
||||
|
||||
All /customer management endpoints
|
||||
|
||||
/customer/new
|
||||
/customer/info
|
||||
/customer/update
|
||||
/customer/delete
|
||||
"""
|
||||
|
||||
#### END-USER/CUSTOMER MANAGEMENT ####
|
||||
import traceback
|
||||
from typing import List, Optional
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/end_user/block",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
@router.post(
    "/customer/block",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def block_user(data: BlockUsers):
    """
    [BETA] Reject calls with this end-user id

    Parameters:
    - user_ids (List[str], required): The unique `user_id`s for the users to block

    (any /chat/completion call with this user={end-user-id} param, will be rejected.)

    Each id is upserted into the end-user table with `blocked=True`: unknown
    ids are created, existing ones are updated.

    ```
    curl -X POST "http://0.0.0.0:8000/customer/block"
    -H "Authorization: Bearer sk-1234"
    -d '{
    "user_ids": [<user_id>, ...]
    }'
    ```

    Returns `{"blocked_users": [...]}` with the upserted records.

    Raises:
    - HTTPException(500) when no database is connected or an upsert fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    # Checked outside the try block so this error is not caught and re-wrapped
    # (which previously buried the message inside a stringified exception).
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": "Postgres DB Not connected"},
        )

    try:
        records = []
        for user_id in data.user_ids:
            record = await prisma_client.db.litellm_endusertable.upsert(
                where={"user_id": user_id},  # type: ignore
                data={
                    "create": {"user_id": user_id, "blocked": True},  # type: ignore
                    "update": {"blocked": True},
                },
            )
            records.append(record)

        return {"blocked_users": records}
    except HTTPException:
        # Already a well-formed HTTP error; pass it through untouched.
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"An error occurred - {str(e)}")
        raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
@router.post(
    "/end_user/unblock",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
@router.post(
    "/customer/unblock",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def unblock_user(data: BlockUsers):
    """
    [BETA] Unblock calls with this user id

    Removes the given ids from the in-memory `litellm.blocked_user_list`.
    Ids that are not currently in the list are ignored, so the call is safe
    to repeat.

    Example
    ```
    curl -X POST "http://0.0.0.0:8000/customer/unblock"
    -H "Authorization: Bearer sk-1234"
    -d '{
    "user_ids": [<user_id>, ...]
    }'
    ```

    Returns `{"blocked_users": [...]}` with the remaining blocked ids.

    Raises:
    - HTTPException(400) when the blocked-user callback / list was never configured.
    - HTTPException(500) when `blocked_user_list` is not a list.
    """
    from enterprise.enterprise_hooks.blocked_user_list import (
        _ENTERPRISE_BlockedUserList,
    )

    if (
        not any(isinstance(x, _ENTERPRISE_BlockedUserList) for x in litellm.callbacks)
        or litellm.blocked_user_list is None
    ):
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Blocked user check was never set. This call has no effect."
            },
        )

    if not isinstance(litellm.blocked_user_list, list):
        raise HTTPException(
            status_code=500,
            detail={
                "error": "`blocked_user_list` must be set as a list. Filepaths can't be updated."
            },
        )

    for user_id in data.user_ids:
        # list.remove raises ValueError for a missing element; unblocking an
        # id that is not blocked should be a no-op, not a server error.
        if user_id in litellm.blocked_user_list:
            litellm.blocked_user_list.remove(user_id)

    return {"blocked_users": litellm.blocked_user_list}
|
||||
|
||||
|
||||
def new_budget_request(data: NewCustomerRequest) -> Optional[BudgetNewRequest]:
    """
    Build a budget request from the budget-related values set on `data`.

    Every field of `BudgetNewRequest` except `budget_id` is read from `data`;
    values that are missing or None are dropped. Returns None when nothing
    remains, i.e. the customer request carries no new budget parameters.
    """
    candidates = {
        name: getattr(data, name, None)
        for name in BudgetNewRequest.model_fields
        if name != "budget_id"
    }
    provided = {name: val for name, val in candidates.items() if val is not None}

    return BudgetNewRequest(**provided) if provided else None
|
||||
|
||||
|
||||
@router.post(
    "/end_user/new",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/customer/new",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_end_user(
    data: NewCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow creating a new Customer


    Parameters:
    - user_id: str - The unique identifier for the user.
    - alias: Optional[str] - A human-friendly alias for the user.
    - blocked: bool - Flag to allow or disallow requests for this end-user. Default is False.
    - max_budget: Optional[float] - The maximum budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
    - budget_id: Optional[str] - The identifier for an existing budget allocated to the user. Either 'max_budget' or 'budget_id' should be provided, not both.
    - allowed_model_region: Optional[Union[Literal["eu"], Literal["us"]]] - Require all user requests to use models in this specific region.
    - default_model: Optional[str] - If no equivalent model in the allowed region, default all requests to this model.
    - metadata: Optional[dict] = Metadata for customer, store information for customer. Example metadata = {"data_training_opt_out": True}
    - budget_duration: Optional[str] - Budget is reset at the end of specified duration. If not set, budget is never reset. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
    - tpm_limit: Optional[int] - [Not Implemented Yet] Specify tpm limit for a given customer (Tokens per minute)
    - rpm_limit: Optional[int] - [Not Implemented Yet] Specify rpm limit for a given customer (Requests per minute)
    - model_max_budget: Optional[dict] - [Not Implemented Yet] Specify max budget for a given model. Example: {"openai/gpt-4o-mini": {"max_budget": 100.0, "budget_duration": "1d"}}
    - max_parallel_requests: Optional[int] - [Not Implemented Yet] Specify max parallel requests for a given customer.
    - soft_budget: Optional[float] - [Not Implemented Yet] Get alerts when customer crosses given budget, doesn't block requests.


    - Allow specifying allowed regions
    - Allow specifying default model

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
    "user_id" : "ishaan-jaff-3",
    "allowed_region": "eu",
    "budget_id": "free_tier",
    "default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model?
    }'

    # return end-user object
    ```

    NOTE: This used to be called `/end_user/new`, we will still be maintaining compatibility for /end_user/XXX for these endpoints

    Raises:
    - HTTPException(500) when no database client is available.
    - ProxyException(400) when a customer with this `user_id` already exists.
    - ProxyException carrying the original status code for validation / budget
      creation failures, or 500 for any other error.
    """
    # Flow:
    # - validate that the default model exists on the proxy
    # - create a budget row when budget params are passed, otherwise reuse budget_id
    # - write the customer to the end user table and return the created record
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )
    try:
        ## VALIDATION ##
        if data.default_model is not None:
            if llm_router is None:
                raise HTTPException(
                    status_code=422,
                    detail={"error": CommonProxyErrors.no_llm_router.value},
                )
            elif data.default_model not in llm_router.get_model_names():
                raise HTTPException(
                    status_code=422,
                    detail={
                        "error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format(
                            data.default_model, set(llm_router.get_model_names())
                        )
                    },
                )

        new_end_user_obj: Dict = {}

        ## CREATE BUDGET ## if set
        _new_budget = new_budget_request(data)
        if _new_budget is not None:
            try:
                budget_record = await prisma_client.db.litellm_budgettable.create(
                    data={
                        **_new_budget.model_dump(exclude_unset=True),
                        "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,  # type: ignore
                        "updated_by": user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                    }
                )
            except Exception as e:
                raise HTTPException(status_code=422, detail={"error": str(e)})

            new_end_user_obj["budget_id"] = budget_record.budget_id
        elif data.budget_id is not None:
            new_end_user_obj["budget_id"] = data.budget_id

        # model_dump is the pydantic v2 spelling of the deprecated .dict(),
        # matching the other dumps in this module.
        _user_data = data.model_dump(exclude_none=True)

        # Budget fields live on the budget row, not on the customer row.
        budget_field_names = BudgetNewRequest.model_fields.keys()
        for k, v in _user_data.items():
            if k not in budget_field_names:
                new_end_user_obj[k] = v

        ## WRITE TO DB ##
        end_user_record = await prisma_client.db.litellm_endusertable.create(
            data=new_end_user_obj,  # type: ignore
            include={"litellm_budget_table": True},
        )

        return end_user_record
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.management_endpoints.customer_endpoints.new_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        # A duplicate user_id is a client error, whichever layer reported it.
        if "Unique constraint failed on the fields: (`user_id`)" in str(e):
            raise ProxyException(
                message=f"Customer already exists, passed user_id={data.user_id}. Please pass a new user_id.",
                type="bad_request",
                code=400,
                param="user_id",
            )

        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type="internal_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type="internal_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.get(
    "/customer/info",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_EndUserTable,
)
@router.get(
    "/end_user/info",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def end_user_info(
    end_user_id: str = fastapi.Query(
        description="End User ID in the request parameters"
    ),
):
    """
    Fetch one end-user record. An `end_user` is a customer (external user) of the proxy.

    The row is looked up by `user_id` and returned together with its linked
    budget table entry; `None`-valued fields are omitted from the response.

    Parameters:
    - end_user_id (str, required): The unique identifier for the end-user

    Raises:
    - HTTPException 500: no database client is configured
    - HTTPException 400: no end-user with that id exists

    Example curl:
    ```
    curl -X GET 'http://localhost:4000/customer/info?end_user_id=test-litellm-user-4' \
        -H 'Authorization: Bearer sk-1234'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    customer_row = await prisma_client.db.litellm_endusertable.find_first(
        where={"user_id": end_user_id},
        include={"litellm_budget_table": True},
    )

    # Happy path first; the not-found error is the fall-through.
    if customer_row is not None:
        return customer_row.model_dump(exclude_none=True)

    raise HTTPException(
        status_code=400,
        detail={"error": "End User Id={} does not exist in db".format(end_user_id)},
    )
|
||||
|
||||
|
||||
@router.post(
    "/customer/update",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/end_user/update",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def update_end_user(
    data: UpdateCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing customer (end-user) row, identified by `user_id`.

    Only values that are not `None`, `[]`, `{}` or `0` are written, so fields
    the caller left at their defaults are not reset. An explicitly supplied
    `false` boolean (e.g. `"blocked": false`) is written.

    Parameters:
    - user_id: str
    - alias: Optional[str] = None # human-friendly alias
    - blocked: bool = False # allow/disallow requests for this end-user
    - max_budget: Optional[float] = None
    - budget_id: Optional[str] = None # give either a budget_id or max_budget
    - allowed_model_region: Optional[AllowedModelRegion] = (
        None # require all user requests to use models in this specific region
    )
    - default_model: Optional[str] = (
        None # if no equivalent model in allowed region - default all requests to this model
    )

    Raises:
    - ProxyException: on any failure (missing DB client, missing/unknown
      `user_id`, database error). HTTPExceptions keep their status code;
      everything else is reported as a 500.

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/update' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "user_id": "test-litellm-user-4",
            "budget_id": "paid_tier"
        }'
    ```
    """

    from litellm.proxy.proxy_server import prisma_client

    try:
        data_json: dict = data.json()
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        # Names of the fields the client actually sent in the request body.
        explicitly_set = getattr(data, "model_fields_set", set())

        # get non default values for key
        non_default_values = {}
        for k, v in data_json.items():
            if v is None:
                continue
            if v is False:
                # `False == 0`, so the default-value filter below would also
                # discard an explicit `false` and make it impossible to e.g.
                # unblock a customer. Forward it only when the caller sent it,
                # so an omitted boolean still does not reset the stored value.
                if k in explicitly_set:
                    non_default_values[k] = v
                continue
            if v not in (
                [],
                {},
                0,
            ):  # models default to [], spend defaults to 0, we should not reset these values
                non_default_values[k] = v

        verbose_proxy_logger.debug("/customer/update: Received data = %s", data)
        if data.user_id is not None and len(data.user_id) > 0:
            non_default_values["user_id"] = data.user_id  # type: ignore
            verbose_proxy_logger.debug("In update customer, user_id condition block.")
            response = await prisma_client.db.litellm_endusertable.update(
                where={"user_id": data.user_id}, data=non_default_values  # type: ignore
            )
            if response is None:
                raise ValueError(
                    f"Failed updating customer data. User ID does not exist passed user_id={data.user_id}"
                )
            verbose_proxy_logger.debug(
                f"received response from updating prisma client. response={response}"
            )
            return response
        else:
            raise ValueError(f"user_id is required, passed user_id = {data.user_id}")

    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.management_endpoints.customer_endpoints.update_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type="internal_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type="internal_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.post(
    "/customer/delete",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/end_user/delete",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_end_user(
    data: DeleteCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete multiple end-users.

    Parameters:
    - user_ids (List[str], required): The unique `user_id`s for the users to delete

    Returns:
    - dict with `deleted_customers` (the count reported by the database) and a
      human-readable `message`.

    Raises:
    - ProxyException: on any failure (missing DB client, empty `user_ids`,
      fewer rows deleted than distinct ids requested, database error).

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/customer/delete' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "user_ids" :["ishaan-jaff-5"]
        }'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        verbose_proxy_logger.debug("/customer/delete: Received data = %s", data)
        if (
            data.user_ids is not None
            and isinstance(data.user_ids, list)
            and len(data.user_ids) > 0
        ):
            # A repeated id matches a single row, so compare the deleted count
            # against the number of distinct ids. Otherwise a request with
            # duplicates reports failure after the rows are already gone.
            unique_user_ids = list(dict.fromkeys(data.user_ids))
            response = await prisma_client.db.litellm_endusertable.delete_many(
                where={"user_id": {"in": unique_user_ids}}
            )
            if response is None:
                raise ValueError(
                    f"Failed deleting customer data. User ID does not exist passed user_id={data.user_ids}"
                )
            if response != len(unique_user_ids):
                raise ValueError(
                    f"Failed deleting all customer data. User ID does not exist passed user_id={data.user_ids}. Deleted {response} customers, passed {len(unique_user_ids)} customers"
                )
            verbose_proxy_logger.debug(
                f"received response from updating prisma client. response={response}"
            )
            return {
                "deleted_customers": response,
                "message": "Successfully deleted customers with ids: "
                + str(data.user_ids),
            }
        else:
            raise ValueError(f"user_id is required, passed user_id = {data.user_ids}")

    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.management_endpoints.customer_endpoints.delete_end_user(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type="internal_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type="internal_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.get(
    "/customer/list",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_EndUserTable],
)
@router.get(
    "/end_user/list",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def list_end_user(
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Admin-only] Return every customer row, each with its linked budget.

    Callers whose role is neither PROXY_ADMIN nor PROXY_ADMIN_VIEW_ONLY are
    rejected with a 401; a missing database client yields a 400.

    Example curl:
    ```
    curl --location --request GET 'http://0.0.0.0:4000/customer/list' \
        --header 'Authorization: Bearer sk-1234'
    ```

    """
    from litellm.proxy.proxy_server import prisma_client

    caller_role = user_api_key_dict.user_role
    is_admin = caller_role == LitellmUserRoles.PROXY_ADMIN
    is_view_only_admin = caller_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
    if not (is_admin or is_view_only_admin):
        raise HTTPException(
            status_code=401,
            detail={
                "error": "Admin-only endpoint. Your user role={}".format(
                    user_api_key_dict.user_role
                )
            },
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    rows = await prisma_client.db.litellm_endusertable.find_many(
        include={"litellm_budget_table": True}
    )

    # Re-validate each DB row into the response model.
    customers: List[LiteLLM_EndUserTable] = [
        LiteLLM_EndUserTable(**row.model_dump()) for row in rows
    ]
    return customers
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,869 @@
|
||||
"""
|
||||
Allow proxy admin to add/update/delete models in the db
|
||||
|
||||
Currently most endpoints are in `proxy_server.py`, but those should be moved here over time.
|
||||
|
||||
Endpoints here:
|
||||
|
||||
model/{model_id}/update - PATCH endpoint for model update.
|
||||
"""
|
||||
|
||||
#### MODEL MANAGEMENT ####
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict, List, Literal, Optional, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_ProxyModelTable,
|
||||
LiteLLM_TeamTable,
|
||||
LitellmTableNames,
|
||||
LitellmUserRoles,
|
||||
ModelInfoDelete,
|
||||
PrismaCompatibleUpdateDBModel,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
TeamModelAddRequest,
|
||||
UpdateTeamRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
||||
from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
team_model_add,
|
||||
update_team,
|
||||
)
|
||||
from litellm.proxy.management_helpers.audit_logs import create_object_audit_log
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.router import (
|
||||
Deployment,
|
||||
DeploymentTypedDict,
|
||||
LiteLLMParamsTypedDict,
|
||||
updateDeployment,
|
||||
)
|
||||
from litellm.utils import get_utc_datetime
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def get_db_model(
    model_id: str, prisma_client: PrismaClient
) -> Optional[Deployment]:
    """
    Load a stored model row by `model_id` and convert it to a `Deployment`.

    Returns `None` when the lookup yields nothing; otherwise the row is dumped
    (dropping `None` fields) and re-validated as a `Deployment`.
    """
    row = await prisma_client.db.litellm_proxymodeltable.find_unique(
        where={"model_id": model_id}
    )
    stored_model = cast(Optional[BaseModel], row)

    if stored_model:
        return Deployment(**stored_model.model_dump(exclude_none=True))
    return None
|
||||
|
||||
|
||||
def update_db_model(
    db_model: Deployment, updated_patch: updateDeployment
) -> PrismaCompatibleUpdateDBModel:
    """
    Merge a patch into a stored deployment and serialise it for a Prisma update.

    The stored `model_name` and `litellm_params` are the starting point. Any
    truthy `model_name`, `litellm_params` or `model_info` on the patch is
    layered on top. Patched litellm params are passed through
    `encrypt_value_helper` first. `litellm_params` and `model_info` are
    JSON-encoded in the returned dict.

    NOTE(review): the merge base does not include `db_model.model_info`, so
    when the patch carries `model_info` the serialised value holds the patch's
    fields only. Confirm that is the intended PATCH behaviour.
    """
    merged = DeploymentTypedDict(
        model_name=db_model.model_name,
        litellm_params=LiteLLMParamsTypedDict(
            **db_model.litellm_params.model_dump(exclude_none=True)  # type: ignore
        ),
    )

    # model name
    patched_name = updated_patch.model_name
    if patched_name:
        merged["model_name"] = patched_name

    # litellm params: encrypt every patched value, then overlay in one step
    patched_params = updated_patch.litellm_params
    if patched_params:
        encrypted_overlay = {}
        for param_name, raw_value in patched_params.model_dump(
            exclude_none=True
        ).items():
            encrypted_overlay[param_name] = encrypt_value_helper(raw_value)
        merged["litellm_params"].update(encrypted_overlay)  # type: ignore

    # model info
    patched_info = updated_patch.model_info
    if patched_info:
        info_target = merged.setdefault("model_info", {})  # type: ignore
        info_target.update(patched_info.model_dump(exclude_none=True))

    # convert to prisma compatible format: JSON columns are stored as strings
    prisma_update = PrismaCompatibleUpdateDBModel()
    if "model_name" in merged:
        prisma_update["model_name"] = merged["model_name"]
    for json_column in ("litellm_params", "model_info"):
        if json_column in merged:
            prisma_update[json_column] = json.dumps(merged[json_column])  # type: ignore

    return prisma_update
|
||||
|
||||
|
||||
@router.patch(
    "/model/{model_id}/update",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def patch_model(
    model_id: str,  # Get model_id from path parameter
    patch_data: updateDeployment,  # Create a specific schema for PATCH operations
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    PATCH Endpoint for partial model updates.

    Only updates the fields specified in the request while preserving other existing values.
    Follows proper PATCH semantics by only modifying provided fields.

    Args:
        model_id: The ID of the model to update
        patch_data: The fields to update and their new values
        user_api_key_dict: User authentication information

    Returns:
        Updated model information

    Raises:
        HTTPException: When no database is connected, or when the caller is
            not permitted to modify this model.
        ProxyException: For validation, not-found and unexpected errors.
    """
    from litellm.proxy.proxy_server import (
        litellm_proxy_admin_name,
        llm_router,
        premium_user,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Verify model exists and is stored in DB
        if not store_model_in_db:
            raise ProxyException(
                message="Model updates only supported for DB-stored models",
                type=ProxyErrorTypes.validation_error.value,
                code=status.HTTP_400_BAD_REQUEST,
                param=None,
            )

        # Fetch existing model
        db_model = await get_db_model(model_id=model_id, prisma_client=prisma_client)

        if db_model is None:
            # Check if model exists in config but not DB
            if llm_router and llm_router.get_deployment(model_id=model_id) is not None:
                raise ProxyException(
                    message="Cannot edit config-based model. Store model in DB via /model/new first.",
                    type=ProxyErrorTypes.validation_error.value,
                    code=status.HTTP_400_BAD_REQUEST,
                    param=None,
                )
            raise ProxyException(
                message=f"Model {model_id} not found on proxy.",
                type=ProxyErrorTypes.not_found_error,
                code=status.HTTP_404_NOT_FOUND,
                param=None,
            )

        # Same permission gate the other write endpoints in this module apply
        # (/model/new, /model/update, /model/delete), evaluated against the
        # stored model.
        await ModelManagementAuthChecks.can_user_make_model_call(
            model_params=db_model,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            premium_user=premium_user,
        )

        # Create update dictionary only for provided fields
        update_data = update_db_model(db_model=db_model, updated_patch=patch_data)

        # Add metadata about update
        update_data["updated_by"] = (
            user_api_key_dict.user_id or litellm_proxy_admin_name
        )
        update_data["updated_at"] = cast(str, get_utc_datetime())

        # Perform partial update
        updated_model = await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": model_id},
            data=update_data,
        )

        return updated_model

    except Exception as e:
        verbose_proxy_logger.exception(f"Error in patch_model: {str(e)}")

        # Already carry a status code / error type; pass through untouched.
        if isinstance(e, (HTTPException, ProxyException)):
            raise e

        raise ProxyException(
            message=f"Error updating model: {str(e)}",
            type=ProxyErrorTypes.internal_server_error,
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            param=None,
        )
|
||||
|
||||
|
||||
################################# Helper Functions #################################
|
||||
####################################################################################
|
||||
####################################################################################
|
||||
####################################################################################
|
||||
|
||||
|
||||
async def _add_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
    new_encryption_key: Optional[str] = None,
    should_create_model_in_db: bool = True,
) -> Optional[LiteLLM_ProxyModelTable]:
    """
    Encrypt a deployment's litellm params and persist it as a proxy model row.

    Every non-`None` litellm param is passed through `encrypt_value_helper`
    and written back onto `model_params.litellm_params`, so the caller's
    object is mutated in place.

    Args:
        model_params: Deployment to store.
        user_api_key_dict: Caller identity; its `user_id` (or the proxy admin
            name when unset) is recorded as `created_by` / `updated_by`.
        prisma_client: Database client used for the insert.
        new_encryption_key: Forwarded to `encrypt_value_helper`.
        should_create_model_in_db: When False, nothing is written and a
            `LiteLLM_ProxyModelTable` is built from the same data instead.

    Returns:
        The created row, or the in-memory table object when not persisting.
    """
    # encrypt litellm params #
    _litellm_params_dict = model_params.litellm_params.dict(exclude_none=True)
    for k, v in _litellm_params_dict.items():
        encrypted_value = encrypt_value_helper(
            value=v, new_encryption_key=new_encryption_key
        )
        model_params.litellm_params[k] = encrypted_value

    _data: dict = {
        "model_id": model_params.model_info.id,
        "model_name": model_params.model_name,
        "litellm_params": model_params.litellm_params.model_dump_json(exclude_none=True),  # type: ignore
        "model_info": model_params.model_info.model_dump_json(  # type: ignore
            exclude_none=True
        ),
        "created_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
        "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
    }

    if should_create_model_in_db:
        model_response = await prisma_client.db.litellm_proxymodeltable.create(
            data=_data  # type: ignore
        )
    else:
        model_response = LiteLLM_ProxyModelTable(**_data)
    return model_response
|
||||
|
||||
|
||||
async def _add_team_model_to_db(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth,
    prisma_client: PrismaClient,
) -> Optional[LiteLLM_ProxyModelTable]:
    """
    Store a team-scoped model under a generated name and alias it for the team.

    Does nothing (returns `None`) when `model_info.team_id` is unset. Otherwise:

    - records the caller-supplied name as `model_info.team_public_model_name`
    - renames the deployment to `model_name_{team_id}_{uuid}` and stores it
    - registers the team alias {public name: generated name}
    - adds the public name to the team's model list

    Note that `model_params` is modified in place.
    """
    team_id = model_params.model_info.team_id
    if team_id is None:
        return None

    public_name = model_params.model_name
    if public_name:
        model_params.model_info.team_public_model_name = public_name

    internal_name = f"model_name_{team_id}_{uuid.uuid4()}"
    model_params.model_name = internal_name

    ## CREATE MODEL IN DB ##
    stored_model = await _add_model_to_db(
        model_params=model_params,
        user_api_key_dict=user_api_key_dict,
        prisma_client=prisma_client,
    )

    ## CREATE MODEL ALIAS IN DB ##
    alias_request = UpdateTeamRequest(
        team_id=team_id,
        model_aliases={public_name: internal_name},
    )
    await update_team(
        data=alias_request,
        user_api_key_dict=user_api_key_dict,
        http_request=Request(scope={"type": "http"}),
    )

    # add model to team object
    add_request = TeamModelAddRequest(
        team_id=team_id,
        models=[public_name],
    )
    await team_model_add(
        data=add_request,
        http_request=Request(scope={"type": "http"}),
        user_api_key_dict=user_api_key_dict,
    )

    return stored_model
|
||||
|
||||
|
||||
class ModelManagementAuthChecks:
    """
    Shared permission checks used by the model management endpoints.

    Each check returns `True` on success and raises `HTTPException` otherwise.
    """

    @staticmethod
    def can_user_make_team_model_call(
        team_id: str,
        user_api_key_dict: UserAPIKeyAuth,
        team_obj: Optional[LiteLLM_TeamTable] = None,
        premium_user: bool = False,
    ) -> Literal[True]:
        """
        Allow a team-model operation for proxy admins and admins of the team.

        Raises a 403 when `premium_user` is False, or when the caller is
        neither a proxy admin nor (per `_is_user_team_admin`) an admin of
        `team_obj`.
        """
        if premium_user is False:
            raise HTTPException(
                status_code=403,
                detail={"error": CommonProxyErrors.not_premium_user.value},
            )

        caller_role = user_api_key_dict.user_role
        if caller_role and caller_role == LitellmUserRoles.PROXY_ADMIN:
            return True

        # Without a team object there is nothing to be admin of.
        caller_is_team_admin = team_obj is not None and _is_user_team_admin(
            user_api_key_dict=user_api_key_dict, team_obj=team_obj
        )
        if caller_is_team_admin:
            return True

        raise HTTPException(
            status_code=403,
            detail={
                "error": "Team ID={} does not match the API key's team ID={}, OR you are not the admin for this team. Check `/user/info` to verify your team admin status.".format(
                    team_id, user_api_key_dict.team_id
                )
            },
        )

    @staticmethod
    async def allow_team_model_action(
        model_params: Union[Deployment, updateDeployment],
        user_api_key_dict: UserAPIKeyAuth,
        prisma_client: PrismaClient,
        premium_user: bool,
    ) -> Literal[True]:
        """
        Gate an action on a model that may belong to a team.

        Models without a `model_info.team_id` pass immediately. For team
        models the caller must be a premium user, the team must exist in the
        database, and `can_user_make_team_model_call` must accept the caller.
        """
        model_info = model_params.model_info
        if model_info is None or model_info.team_id is None:
            return True

        if premium_user is not True:
            raise HTTPException(
                status_code=403,
                detail={"error": CommonProxyErrors.not_premium_user.value},
            )

        team_row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": model_info.team_id}
        )
        if team_row is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Team id={} does not exist in db".format(
                        model_info.team_id
                    )
                },
            )

        ModelManagementAuthChecks.can_user_make_team_model_call(
            team_id=model_info.team_id,
            user_api_key_dict=user_api_key_dict,
            team_obj=LiteLLM_TeamTable(**team_row.model_dump()),
            premium_user=premium_user,
        )
        return True

    @staticmethod
    async def can_user_make_model_call(
        model_params: Deployment,
        user_api_key_dict: UserAPIKeyAuth,
        prisma_client: PrismaClient,
        premium_user: bool,
    ) -> Literal[True]:
        """
        Decide whether the caller may create, edit or delete this model.

        Team models (those with `model_info.team_id`) defer to
        `can_user_make_team_model_call` after the team is loaded from the
        database. All other models require the PROXY_ADMIN role.
        """
        model_info = model_params.model_info
        is_team_model = model_info is not None and model_info.team_id is not None

        ## Check non-team model auth
        if not is_team_model:
            if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
                raise HTTPException(
                    status_code=403,
                    detail={
                        "error": "User does not have permission to make this model call. Your role={}. You can only make model calls if you are a PROXY_ADMIN or if you are a team admin, by specifying a team_id in the model_info.".format(
                            user_api_key_dict.user_role
                        )
                    },
                )
            return True

        ## Check team model auth
        team_row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": model_info.team_id}
        )
        if team_row is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Team id={} does not exist in db".format(
                        model_info.team_id
                    )
                },
            )

        return ModelManagementAuthChecks.can_user_make_team_model_call(
            team_id=model_info.team_id,
            user_api_key_dict=user_api_key_dict,
            team_obj=LiteLLM_TeamTable(**team_row.model_dump()),
            premium_user=premium_user,
        )
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
||||
@router.post(
    "/model/delete",
    description="Allows deleting models in the model list in the config.yaml",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_model(
    model_info: ModelInfoDelete,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964

    Delete a DB-stored model by id:

    - look the model up in the db (400 if it is not there)
    - check the caller may modify it
    - delete the row, remove the deployment from the router and schedule an
      audit log entry

    Raises:
        ProxyException: on any failure. HTTPExceptions raised along the way
            are re-wrapped and keep their status code.
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        premium_user,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        model_in_db = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": model_info.id}
        )
        if model_in_db is None:
            raise HTTPException(
                status_code=400,
                detail={"error": f"Model with id={model_info.id} not found in db"},
            )

        model_params = Deployment(**model_in_db.model_dump())
        await ModelManagementAuthChecks.can_user_make_model_call(
            model_params=model_params,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            premium_user=premium_user,
        )

        # update DB
        if store_model_in_db is True:
            result = await prisma_client.db.litellm_proxymodeltable.delete(
                where={"model_id": model_info.id}
            )

            if result is None:
                raise HTTPException(
                    status_code=400,
                    detail={"error": f"Model with id={model_info.id} not found in db"},
                )

            ## DELETE FROM ROUTER ##
            if llm_router is not None:
                llm_router.delete_deployment(id=model_info.id)

            ## CREATE AUDIT LOG ##
            # Scheduled as a task and not awaited before the response is returned.
            asyncio.create_task(
                create_object_audit_log(
                    object_id=model_info.id,
                    action="deleted",
                    user_api_key_dict=user_api_key_dict,
                    table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                    before_value=result.model_dump_json(exclude_none=True),
                    after_value=None,
                    litellm_changed_by=user_api_key_dict.user_id,
                    litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
                )
            )
            return {"message": f"Model: {result.model_id} deleted successfully"}
        else:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )

    except Exception as e:
        verbose_proxy_logger.exception(
            f"Failed to delete model. Due to error - {str(e)}"
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964
|
||||
@router.post(
    "/model/new",
    description="Allows adding new models to the model list in the config.yaml",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def add_new_model(
    model_params: Deployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA] Store a new model in the database and load it into the proxy.

    Steps:
    - check the caller may add this model (`ModelManagementAuthChecks`)
    - write the model to the db (team models go through `_add_team_model_to_db`)
    - ask the proxy config to pick up the new deployment
    - send a Slack "model added" alert when Slack alerting is configured
    - schedule an audit log entry

    Failures inside the store/load/alert step are logged and swallowed; the
    request then fails with a 500 only if no model row was produced.

    Raises:
        ProxyException: on any failure. HTTPExceptions raised along the way
            are re-wrapped and keep their status code.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        premium_user,
        prisma_client,
        proxy_config,
        proxy_logging_obj,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        ## Auth check
        await ModelManagementAuthChecks.can_user_make_model_call(
            model_params=model_params,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            premium_user=premium_user,
        )

        model_response: Optional[LiteLLM_ProxyModelTable] = None
        # update DB
        if store_model_in_db is True:
            try:
                # Capture the underlying litellm model now: `_add_model_to_db`
                # overwrites every litellm param on `model_params` with its
                # encrypted value, and the alert below wants the real name.
                _original_litellm_model_name = model_params.litellm_params.model
                if model_params.model_info.team_id is None:
                    model_response = await _add_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                else:
                    model_response = await _add_team_model_to_db(
                        model_params=model_params,
                        user_api_key_dict=user_api_key_dict,
                        prisma_client=prisma_client,
                    )
                await proxy_config.add_deployment(
                    prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
                )
                # don't let failed slack alert block the /model/new response
                _alerting = general_settings.get("alerting", []) or []
                if "slack" in _alerting:
                    # send notification - new model added
                    await proxy_logging_obj.slack_alerting_instance.model_added_alert(
                        model_name=model_params.model_name,
                        litellm_model_name=_original_litellm_model_name,
                        passed_model_info=model_params.model_info,
                    )
            except Exception as e:
                # Best-effort: `model_response` stays set if the DB write
                # succeeded and only a later step failed.
                verbose_proxy_logger.exception(f"Exception in add_new_model: {e}")

        else:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
                },
            )

        if model_response is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to add model to db. Check your server logs for more details."
                },
            )

        ## CREATE AUDIT LOG ##
        asyncio.create_task(
            create_object_audit_log(
                object_id=model_response.model_id,
                action="created",
                user_api_key_dict=user_api_key_dict,
                table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                before_value=None,
                after_value=(
                    model_response.model_dump_json(exclude_none=True)
                    if isinstance(model_response, BaseModel)
                    else None
                ),
                litellm_changed_by=user_api_key_dict.user_id,
                litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
            )
        )

        return model_response

    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.add_new_model(): Exception occured - {}".format(
                str(e)
            )
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
#### MODEL MANAGEMENT ####
|
||||
@router.post(
    "/model/update",
    description="Edit existing model params",
    tags=["model management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_model(
    model_params: updateDeployment,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Old endpoint for model update. Makes a PUT request.

    Use `/model/{model_id}/update` to PATCH the stored model in db.

    Steps performed here:
    1. Require a DB connection plus `model_info` / `model_info.id` on the request.
    2. Load the stored row for that id. A missing row is reported as a 400 when
       the router still knows the deployment (i.e. it lives in the config),
       otherwise as "model not found".
    3. Run `ModelManagementAuthChecks.can_user_make_model_call` against the
       stored deployment.
    4. Only when `store_model_in_db` is True: encrypt the submitted params,
       overlay them on the stored params, write the row, schedule an audit
       log task and return the updated row.

    Returns None when `store_model_in_db` is not True.
    Every failure is re-raised as a `ProxyException`.
    """
    from litellm.proxy.proxy_server import (
        LITELLM_PROXY_ADMIN_NAME,
        llm_router,
        premium_user,
        prisma_client,
        store_model_in_db,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "No DB Connected. Here's how to do it - https://docs.litellm.ai/docs/proxy/virtual_keys"
                },
            )

        requested_info = getattr(model_params, "model_info", None)
        if requested_info is None:
            raise Exception("model_info not provided")

        target_id = requested_info.id
        if target_id is None:
            raise Exception("model_info.id not provided")

        stored_row = await prisma_client.db.litellm_proxymodeltable.find_unique(
            where={"model_id": target_id}
        )

        if stored_row is None:
            # Distinguish "defined in config only" from "unknown model".
            defined_in_config = (
                llm_router is not None
                and llm_router.get_deployment(model_id=target_id) is not None
            )
            if defined_in_config:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
                    },
                )
            raise Exception("model not found")

        stored_deployment = Deployment(**stored_row.model_dump())

        await ModelManagementAuthChecks.can_user_make_model_call(
            model_params=stored_deployment,
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            premium_user=premium_user,
        )

        # Nothing is persisted unless DB storage of models is enabled.
        if store_model_in_db is not True:
            return None

        stored_params = dict(stored_row.litellm_params)

        if model_params.litellm_params is None:
            raise Exception("litellm_params not provided")

        ### ENCRYPT PARAMS ###
        # Iterate over a snapshot dict so the model can be mutated in the loop.
        submitted_snapshot = model_params.litellm_params.dict(exclude_none=True)
        for param_name, plain_value in submitted_snapshot.items():
            model_params.litellm_params[param_name] = encrypt_value_helper(
                value=plain_value
            )

        ### MERGE WITH EXISTING DATA ###
        # Submitted (encrypted) values win; a None falls back to the stored
        # value; keys that are None on both sides are left out.
        merged_params = {}
        for param_name, submitted_value in model_params.litellm_params.dict().items():
            chosen = (
                submitted_value
                if submitted_value is not None
                else stored_params.get(param_name)
            )
            if chosen is not None:
                merged_params[param_name] = chosen

        row_update: dict = {
            "litellm_params": json.dumps(merged_params),  # type: ignore
            "updated_by": user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
        }
        model_response = await prisma_client.db.litellm_proxymodeltable.update(
            where={"model_id": target_id},
            data=row_update,  # type: ignore
        )

        ## CREATE AUDIT LOG ##
        before_snapshot = (
            stored_row.model_dump_json(exclude_none=True)
            if isinstance(stored_row, BaseModel)
            else None
        )
        after_snapshot = (
            model_response.model_dump_json(exclude_none=True)
            if isinstance(model_response, BaseModel)
            else None
        )
        asyncio.create_task(
            create_object_audit_log(
                object_id=target_id,
                action="updated",
                user_api_key_dict=user_api_key_dict,
                table_name=LitellmTableNames.PROXY_MODEL_TABLE_NAME,
                before_value=before_snapshot,
                after_value=after_snapshot,
                litellm_changed_by=user_api_key_dict.user_id,
                litellm_proxy_admin_name=LITELLM_PROXY_ADMIN_NAME,
            )
        )

        return model_response
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.update_model(): Exception occured - {}".format(
                str(e)
            )
        )
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
def _deduplicate_litellm_router_models(models: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Deduplicate models based on their model_info.id field.
|
||||
Returns a list of unique models keeping only the first occurrence of each model ID.
|
||||
|
||||
Args:
|
||||
models: List of model dictionaries containing model_info
|
||||
|
||||
Returns:
|
||||
List of deduplicated model dictionaries
|
||||
"""
|
||||
seen_ids = set()
|
||||
unique_models = []
|
||||
for model in models:
|
||||
model_id = model.get("model_info", {}).get("id", None)
|
||||
if model_id is not None and model_id not in seen_ids:
|
||||
unique_models.append(model)
|
||||
seen_ids.add(model_id)
|
||||
return unique_models
|
||||
@@ -0,0 +1,821 @@
|
||||
"""
Endpoints for /organization operations

/organization/new
/organization/update
/organization/delete
/organization/member_add
/organization/member_update
/organization/member_delete
/organization/info
/organization/list
"""
|
||||
|
||||
#### ORGANIZATION MANAGEMENT ####
|
||||
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.budget_management_endpoints import (
|
||||
new_budget,
|
||||
update_budget,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
get_new_internal_user_defaults,
|
||||
management_endpoint_wrapper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/organization/new",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=NewOrganizationResponse,
)
async def new_organization(
    data: NewOrganizationRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow orgs to own teams

    Set org level budgets + model access.

    Only admins can create orgs.

    # Parameters

    - organization_alias: *str* - The name of the organization.
    - models: *List* - The models the organization has access to.
    - budget_id: *Optional[str]* - The id for a budget (tpm/rpm/max budget) for the organization.
    ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
    - max_budget: *Optional[float]* - Max budget for org
    - tpm_limit: *Optional[int]* - Max tpm limit for org
    - rpm_limit: *Optional[int]* - Max rpm limit for org
    - max_parallel_requests: *Optional[int]* - [Not Implemented Yet] Max parallel requests for org
    - soft_budget: *Optional[float]* - [Not Implemented Yet] Get a slack alert when this soft budget is reached. Don't block requests.
    - model_max_budget: *Optional[dict]* - Max budget for a specific model
    - budget_duration: *Optional[str]* - Frequency of resetting org budget
    - metadata: *Optional[dict]* - Metadata for organization, store information for organization. Example metadata - {"extra_info": "some info"}
    - blocked: *bool* - Flag indicating if the org is blocked or not - will stop all calls from keys with this org_id.
    - tags: *Optional[List[str]]* - Tags for [tracking spend](https://litellm.vercel.app/docs/proxy/enterprise#tracking-spend-for-custom-tags) and/or doing [tag-based routing](https://litellm.vercel.app/docs/proxy/tag_routing).
    - organization_id: *Optional[str]* - The organization id of the team. Default is None. Create via `/organization/new`.
    - model_aliases: Optional[dict] - Model aliases for the team. [Docs](https://docs.litellm.ai/docs/proxy/team_based_routing#create-team-with-model-alias)

    Case 1: Create new org **without** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \

    --header 'Authorization: Bearer sk-1234' \

    --header 'Content-Type: application/json' \

    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "max_budget": 100
    }'


    ```

    Case 2: Create new org **with** a budget_id

    ```bash
    curl --location 'http://0.0.0.0:4000/organization/new' \

    --header 'Authorization: Bearer sk-1234' \

    --header 'Content-Type: application/json' \

    --data '{
        "organization_alias": "my-secret-org",
        "models": ["model1", "model2"],
        "budget_id": "428eeaa8-f3ac-4e85-a8fb-7dc8d7aa8689"
    }'
    ```

    Raises HTTPException:
    - 500 when no DB is connected
    - 401 when the caller is not a proxy admin
    - 400 when the caller tries to grant models it does not have access to
    """

    from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # A missing role (None) also fails this comparison.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=401,
            detail={
                "error": f"Only admins can create orgs. Your role is = {user_api_key_dict.user_role}"
            },
        )

    # Ensure only models that user has access to, are given to org.
    # This runs BEFORE any DB write so that a rejected request does not leave
    # an orphan budget row behind.
    if len(user_api_key_dict.models) > 0:  # empty list == access to all models
        if len(data.models) == 0:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "User not allowed to give access to all models. Select models you want org to have access to."
                },
            )
        for m in data.models:
            if m not in user_api_key_dict.models:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": f"User not allowed to give access to model={m}. Models you have access to = {user_api_key_dict.models}"
                    },
                )

    acting_user = user_api_key_dict.user_id or litellm_proxy_admin_name

    if data.budget_id is None:
        # Every organization needs a budget attached.
        # If none provided, create one based on provided values.
        budget_params = LiteLLM_BudgetTable.model_fields.keys()

        # Only include Budget Params when creating an entry in litellm_budgettable
        _json_data = data.json(exclude_none=True)
        _budget_data = {k: v for k, v in _json_data.items() if k in budget_params}
        budget_row = LiteLLM_BudgetTable(**_budget_data)

        # Named `budget_payload` (not `new_budget`) so the imported
        # `new_budget` endpoint function is not shadowed inside this function.
        budget_payload = prisma_client.jsonify_object(
            budget_row.json(exclude_none=True)
        )

        _budget = await prisma_client.db.litellm_budgettable.create(
            data={
                **budget_payload,  # type: ignore
                "created_by": acting_user,
                "updated_by": acting_user,
            }
        )  # type: ignore

        data.budget_id = _budget.budget_id

    organization_row = LiteLLM_OrganizationTable(
        **data.json(exclude_none=True),
        created_by=acting_user,
        updated_by=acting_user,
    )
    new_organization_row = prisma_client.jsonify_object(
        organization_row.json(exclude_none=True)
    )
    verbose_proxy_logger.info(
        f"new_organization_row: {json.dumps(new_organization_row, indent=2)}"
    )
    response = await prisma_client.db.litellm_organizationtable.create(
        data={
            **new_organization_row,  # type: ignore
        }
    )

    return response
|
||||
|
||||
|
||||
@router.patch(
    "/organization/update",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_OrganizationTableWithMembers,
)
async def update_organization(
    data: LiteLLM_OrganizationTableUpdate,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an organization

    The non-None fields of `data` are written to the row identified by
    `data.organization_id`. `updated_by` defaults to the caller's user_id
    when it is not supplied.

    Returns the updated organization including its members, teams and budget.

    Raises HTTPException:
    - 500 when no DB is connected
    - 400 when the calling key has no user_id associated with it
    - 404 when no organization matches `data.organization_id`
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if user_api_key_dict.user_id is None:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Cannot associate a user_id to this action. Check `/key/info` to validate if 'user_id' is set."
            },
        )

    if data.updated_by is None:
        data.updated_by = user_api_key_dict.user_id

    updated_organization_row = prisma_client.jsonify_object(
        data.model_dump(exclude_none=True)
    )

    response = await prisma_client.db.litellm_organizationtable.update(
        where={"organization_id": data.organization_id},
        data=updated_organization_row,
        include={"members": True, "teams": True, "litellm_budget_table": True},
    )

    # The update yields None when no row matches; report that as a 404
    # (same as `delete_organization`) instead of returning None into a
    # non-optional response_model.
    if response is None:
        raise HTTPException(
            status_code=404,
            detail={"error": f"Organization={data.organization_id} not found"},
        )

    return response
|
||||
|
||||
|
||||
@router.delete(
    "/organization/delete",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_OrganizationTableWithMembers],
)
async def delete_organization(
    data: DeleteOrganizationRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete one or more organizations together with their teams, memberships and keys.

    Only proxy admins may call this endpoint.

    # Parameters:

    - organization_ids: List[str] - The organization ids to delete.

    Returns the deleted organization rows, in the order they were requested.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=401,
            detail={"error": "Only proxy admins can delete organizations"},
        )

    removed_rows = []
    for org_id in data.organization_ids:
        # Dependent rows go first: teams, then memberships, then keys.
        await prisma_client.db.litellm_teamtable.delete_many(
            where={"organization_id": org_id}
        )
        await prisma_client.db.litellm_organizationmembership.delete_many(
            where={"organization_id": org_id}
        )
        await prisma_client.db.litellm_verificationtoken.delete_many(
            where={"organization_id": org_id}
        )

        # Finally remove the organization row itself.
        removed = await prisma_client.db.litellm_organizationtable.delete(
            where={"organization_id": org_id},
            include={"members": True, "teams": True, "litellm_budget_table": True},
        )
        if removed is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Organization={org_id} not found"},
            )
        removed_rows.append(removed)

    return removed_rows
|
||||
|
||||
|
||||
@router.get(
    "/organization/list",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[LiteLLM_OrganizationTableWithMembers],
)
async def list_organization(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List organizations, each with its members and teams.

    - Proxy admins get every organization.
    - Any other caller gets only the organizations they are a member of.

    ```
    curl --location --request GET 'http://0.0.0.0:4000/organization/list' \
        --header 'Authorization: Bearer sk-1234'
    ```

    Raises HTTPException(500) when no DB is connected.
    """
    from litellm.proxy.proxy_server import prisma_client

    # Single DB guard. A second, identical check that followed this one could
    # never fire and has been removed.
    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # if proxy admin - get all orgs
    if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN:
        return await prisma_client.db.litellm_organizationtable.find_many(
            include={"members": True, "teams": True}
        )

    # if internal user - get orgs they are a member of
    org_memberships = await prisma_client.db.litellm_organizationmembership.find_many(
        where={"user_id": user_api_key_dict.user_id}
    )
    return await prisma_client.db.litellm_organizationtable.find_many(
        where={
            "organization_id": {
                "in": [membership.organization_id for membership in org_memberships]
            }
        },
        include={"members": True, "teams": True},
    )
|
||||
|
||||
|
||||
@router.get(
    "/organization/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_OrganizationTableWithMembers,
)
async def info_organization(organization_id: str):
    """
    Return a single organization, including its budget, members and teams.

    Raises HTTPException(500) when no DB is connected and HTTPException(404)
    when no organization matches `organization_id`.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    organization_row: Optional[
        LiteLLM_OrganizationTableWithMembers
    ] = await prisma_client.db.litellm_organizationtable.find_unique(
        where={"organization_id": organization_id},
        include={"litellm_budget_table": True, "members": True, "teams": True},
    )
    if organization_row is None:
        raise HTTPException(status_code=404, detail={"error": "Organization not found"})

    # Re-validate the DB row through the response model before returning it.
    return LiteLLM_OrganizationTableWithMembers(**organization_row.model_dump())
|
||||
|
||||
|
||||
@router.post(
    "/organization/info",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def deprecated_info_organization(data: OrganizationRequest):
    """
    DEPRECATED: Use GET /organization/info instead

    Looks up every organization whose id is listed in `data.organizations`
    and returns the matching rows together with their budget.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    requested_ids = data.organizations
    if len(requested_ids) == 0:
        raise HTTPException(
            status_code=400,
            detail={
                "error": f"Specify list of organization id's to query. Passed in={data.organizations}"
            },
        )

    return await prisma_client.db.litellm_organizationtable.find_many(
        where={"organization_id": {"in": requested_ids}},
        include={"litellm_budget_table": True},
    )
|
||||
|
||||
|
||||
@router.post(
    "/organization/member_add",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=OrganizationAddMemberResponse,
)
@management_endpoint_wrapper
async def organization_member_add(
    data: OrganizationMemberAddRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> OrganizationAddMemberResponse:
    """
    [BETA]

    Add new members (either via user_email or user_id) to an organization

    If user doesn't exist, new user row will also be added to User Table

    Only proxy_admin or org_admin of organization, allowed to access this endpoint.

    # Parameters:

    - organization_id: str (required)
    - member: Union[List[Member], Member] (required)
        - role: Literal[LitellmUserRoles] (required)
        - user_id: Optional[str]
        - user_email: Optional[str]

    Note: Either user_id or user_email must be provided for each member.

    Example:
    ```
    curl -X POST 'http://0.0.0.0:4000/organization/member_add' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -d '{
        "organization_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
        "member": {
            "role": "internal_user",
            "user_id": "krrish247652@berri.ai"
        },
        "max_budget_in_organization": 100.0
    }'
    ```

    The following is executed in this function:

    1. Check if organization exists
    2. Creates a new Internal User if the user_id or user_email is not found in LiteLLM_UserTable
    3. Add Internal User to the `LiteLLM_OrganizationMembership` table

    Any failure is logged and re-raised as a `ProxyException`.
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})

        # Step 1: the organization must already exist.
        organization_row = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": data.organization_id}
        )
        if organization_row is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Organization not found for organization_id={getattr(data, 'organization_id', None)}"
                },
            )

        # Accept either a single member or a list of members.
        members: List[OrgMember] = (
            data.member if isinstance(data.member, list) else [data.member]
        )

        # Steps 2 + 3 are delegated to `add_member_to_organization`, one member at a time.
        updated_users: List[LiteLLM_UserTable] = []
        updated_organization_memberships: List[LiteLLM_OrganizationMembershipTable] = []
        for member in members:
            user_row, membership_row = await add_member_to_organization(
                member=member,
                organization_id=data.organization_id,
                prisma_client=prisma_client,
            )
            updated_users.append(user_row)
            updated_organization_memberships.append(membership_row)

        return OrganizationAddMemberResponse(
            organization_id=data.organization_id,
            updated_users=updated_users,
            updated_organization_memberships=updated_organization_memberships,
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding member to organization: {e}")
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type=ProxyErrorTypes.auth_error,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
async def find_member_if_email(
    user_email: str, prisma_client: PrismaClient
) -> LiteLLM_UserTable:
    """
    Find a member if the user_email is in LiteLLM_UserTable

    Args:
        user_email: Email address to look up.
        prisma_client: Connected Prisma client used for the lookup.

    Returns:
        The matching row converted to `LiteLLM_UserTable`.

    Raises:
        HTTPException(400): when the lookup fails, or when no row exists for
            `user_email`.
    """
    not_found_detail = {
        "error": f"Unique user not found for user_email={user_email}. Potential duplicate OR non-existent user_email in LiteLLM_UserTable. Use 'user_id' instead."
    }

    try:
        existing_user_email_row: Optional[
            BaseModel
        ] = await prisma_client.db.litellm_usertable.find_unique(
            where={"user_email": user_email}
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=not_found_detail) from e

    # A lookup that matches nothing yields None rather than raising. Without
    # this guard the `.model_dump()` call below fails with AttributeError
    # instead of the intended 400.
    if existing_user_email_row is None:
        raise HTTPException(status_code=400, detail=not_found_detail)

    return LiteLLM_UserTable(**existing_user_email_row.model_dump())
|
||||
|
||||
|
||||
@router.patch(
    "/organization/member_update",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_OrganizationMembershipTable,
)
@management_endpoint_wrapper
async def organization_member_update(
    data: OrganizationMemberUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update a member's role and/or budget in an organization

    - The member is identified by `user_id`; when only `user_email` is given
      it is resolved to a user_id first.
    - `role`, when provided, is written to the membership row.
    - `max_budget_in_organization`, when provided, creates a new budget (if the
      membership has none) or updates the existing one, then links the budget
      to the membership.

    Returns the refreshed membership row (with its budget) as
    `LiteLLM_OrganizationMembershipTable`.
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # Check if organization exists
        organization_row = await prisma_client.db.litellm_organizationtable.find_unique(
            where={"organization_id": data.organization_id}
        )
        if organization_row is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Organization not found for organization_id={getattr(data, 'organization_id', None)}"
                },
            )

        # Resolve the email to a user_id when that is all the caller supplied.
        if data.user_email is not None and data.user_id is None:
            user_row = await find_member_if_email(data.user_email, prisma_client)
            data.user_id = user_row.user_id

        def membership_key() -> dict:
            # Compound unique key of the membership row (fresh dict per call).
            return {
                "user_id_organization_id": {
                    "user_id": data.user_id,
                    "organization_id": data.organization_id,
                }
            }

        memberships = prisma_client.db.litellm_organizationmembership

        # Check if member exists in organization
        try:
            current_membership = await memberships.find_unique(where=membership_key())
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Error finding organization membership for user_id={data.user_id} in organization={data.organization_id}: {e}"
                },
            )
        if current_membership is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Member not found in organization for user_id={data.user_id}"
                },
            )

        # Update member role
        if data.role is not None:
            await memberships.update(
                where=membership_key(),
                data={"user_role": data.role},
            )

        if data.max_budget_in_organization is not None:
            # Reuse the membership's budget id, or mint a new one.
            linked_budget_id = current_membership.budget_id
            budget_id = linked_budget_id or str(uuid.uuid4())
            budget_request = BudgetNewRequest(
                budget_id=budget_id, max_budget=data.max_budget_in_organization
            )
            if linked_budget_id is None:
                # no budget linked yet - create one
                await new_budget(
                    budget_obj=budget_request, user_api_key_dict=user_api_key_dict
                )
            else:
                # update budget table with new max_budget
                await update_budget(
                    budget_obj=budget_request,
                    user_api_key_dict=user_api_key_dict,
                )

            # update organization membership with new budget_id
            await memberships.update(
                where=membership_key(),
                data={"budget_id": budget_id},
            )

        # Re-read the row so the response reflects every change made above.
        refreshed_membership: Optional[BaseModel] = await memberships.find_unique(
            where=membership_key(),
            include={"litellm_budget_table": True},
        )
        if refreshed_membership is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Member not found in organization={data.organization_id} for user_id={data.user_id}"
                },
            )

        return LiteLLM_OrganizationMembershipTable(
            **refreshed_membership.model_dump(exclude_none=True)
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating member in organization: {e}")
        raise e
|
||||
|
||||
|
||||
@router.delete(
    "/organization/member_delete",
    tags=["organization management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def organization_member_delete(
    data: OrganizationMemberDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Remove a member from an organization

    The member is identified by `user_id`; when only `user_email` is given it
    is resolved to a user_id first. Returns the result of the membership
    delete call. Errors are logged and re-raised unchanged.
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        needs_email_lookup = data.user_email is not None and data.user_id is None
        if needs_email_lookup:
            user_row = await find_member_if_email(data.user_email, prisma_client)
            data.user_id = user_row.user_id

        # Compound unique key of the membership row to remove.
        membership_key = {
            "user_id_organization_id": {
                "user_id": data.user_id,
                "organization_id": data.organization_id,
            }
        }
        return await prisma_client.db.litellm_organizationmembership.delete(
            where=membership_key
        )

    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting member from organization: {e}")
        raise e
|
||||
|
||||
|
||||
async def add_member_to_organization(
    member: OrgMember,
    organization_id: str,
    prisma_client: PrismaClient,
) -> Tuple[LiteLLM_UserTable, LiteLLM_OrganizationMembershipTable]:
    """
    Add a member to an organization

    - Checks if member.user_id or member.user_email is in LiteLLM_UserTable
    - If not found, create a new user in LiteLLM_UserTable
    - Add user to organization in LiteLLM_OrganizationMembership

    Args:
        member: The member to add; `user_id`, `user_email` and `role` are read.
        organization_id: Organization the membership row is created for.
        prisma_client: Connected DB client used for all lookups and writes.

    Returns:
        Tuple of (user record, newly created organization membership record).

    Raises:
        ValueError: On any failure. Every exception raised inside the function
            (including the HTTPException for duplicate emails) is wrapped in a
            ValueError, with the original attached as `__cause__`.
    """

    try:
        user_object: Optional[LiteLLM_UserTable] = None
        existing_user_id_row = None
        existing_user_email_row = None

        ## Check if user exists in LiteLLM_UserTable - either by user_id or by user_email
        if member.user_id is not None:
            existing_user_id_row = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": member.user_id}
            )

        # The email lookup is only a fallback when the user_id lookup found nothing.
        if existing_user_id_row is None and member.user_email is not None:
            try:
                existing_user_email_row = (
                    await prisma_client.db.litellm_usertable.find_unique(
                        where={"user_email": member.user_email}
                    )
                )
            except Exception as e:
                raise ValueError(
                    f"Potential NON-Existent or Duplicate user email in DB: Error finding a unique instance of user_email={member.user_email} in LiteLLM_UserTable.: {e}"
                ) from e

        if existing_user_id_row is None and existing_user_email_row is None:
            ## If user does not exist, create a new user
            user_id: str = member.user_id or str(uuid.uuid4())
            new_user_defaults = get_new_internal_user_defaults(
                user_id=user_id,
                user_email=member.user_email,
            )

            _returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user")  # type: ignore
            if _returned_user is not None:
                user_object = LiteLLM_UserTable(**_returned_user.model_dump())
        elif (
            isinstance(existing_user_email_row, list)
            and len(existing_user_email_row) > 1
        ):
            # find_unique yields a single row, and calling len() on that row
            # raises TypeError. Only apply the duplicate check when a list of
            # rows was actually returned.
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Multiple users with this email found in db. Please use 'user_id' instead."
                },
            )
        else:
            # Exactly one of the two lookups matched (the email lookup only
            # runs when the id lookup found nothing).
            existing_row = (
                existing_user_email_row
                if existing_user_email_row is not None
                else existing_user_id_row
            )
            user_object = LiteLLM_UserTable(**existing_row.model_dump())

        if user_object is None:
            raise ValueError(
                f"User does not exist in LiteLLM_UserTable. user_id={member.user_id} and user_email={member.user_email}"
            )

        # Add user to organization
        _organization_membership = (
            await prisma_client.db.litellm_organizationmembership.create(
                data={
                    "organization_id": organization_id,
                    "user_id": user_object.user_id,
                    "user_role": member.role,
                }
            )
        )
        organization_membership = LiteLLM_OrganizationMembershipTable(
            **_organization_membership.model_dump()
        )
        return user_object, organization_membership

    except Exception as e:
        raise ValueError(
            f"Error adding member={member} to organization={organization_id}: {e}"
        ) from e
|
||||
@@ -0,0 +1,118 @@
|
||||
# SCIM v2 Integration for LiteLLM Proxy
|
||||
|
||||
This module provides SCIM v2 (System for Cross-domain Identity Management) endpoints for LiteLLM Proxy, allowing identity providers to manage users and teams (groups) within the LiteLLM ecosystem.
|
||||
|
||||
## Overview
|
||||
|
||||
SCIM is an open standard designed to simplify user management across different systems. This implementation allows compatible identity providers (like Okta, Azure AD, OneLogin, etc.) to automatically provision and deprovision users and groups in LiteLLM Proxy.
|
||||
|
||||
## Endpoints
|
||||
|
||||
The SCIM v2 API follows the standard specification with the following base URL:
|
||||
|
||||
```
|
||||
/scim/v2
|
||||
```
|
||||
|
||||
### User Management
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/Users` | GET | List all users with pagination support |
|
||||
| `/Users/{user_id}` | GET | Get a specific user by ID |
|
||||
| `/Users` | POST | Create a new user |
|
||||
| `/Users/{user_id}` | PUT | Update an existing user |
|
||||
| `/Users/{user_id}` | DELETE | Delete a user |
|
||||
|
||||
### Group Management
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/Groups` | GET | List all groups with pagination support |
|
||||
| `/Groups/{group_id}` | GET | Get a specific group by ID |
|
||||
| `/Groups` | POST | Create a new group |
|
||||
| `/Groups/{group_id}` | PUT | Update an existing group |
|
||||
| `/Groups/{group_id}` | DELETE | Delete a group |
|
||||
|
||||
## SCIM Schema
|
||||
|
||||
This implementation follows the standard SCIM v2 schema with the following mappings:
|
||||
|
||||
### Users
|
||||
|
||||
- SCIM User ID → LiteLLM `user_id`
|
||||
- SCIM User Email → LiteLLM `user_email`
|
||||
- SCIM User Group Memberships → LiteLLM User-Team relationships
|
||||
|
||||
### Groups
|
||||
|
||||
- SCIM Group ID → LiteLLM `team_id`
|
||||
- SCIM Group Display Name → LiteLLM `team_alias`
|
||||
- SCIM Group Members → LiteLLM Team members list
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable SCIM in your identity provider, use the full URL to the SCIM endpoint:
|
||||
|
||||
```
|
||||
https://your-litellm-proxy-url/scim/v2
|
||||
```
|
||||
|
||||
Most identity providers will require authentication. You should use a valid LiteLLM API key with administrative privileges.
|
||||
|
||||
## Features
|
||||
|
||||
- Full CRUD operations for users and groups
|
||||
- Pagination support
|
||||
- Basic filtering support
|
||||
- Automatic synchronization of user-team relationships
|
||||
- Proper status codes and error handling per SCIM specification
|
||||
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Listing Users
|
||||
|
||||
```
|
||||
GET /scim/v2/Users?startIndex=1&count=10
|
||||
```
|
||||
|
||||
### Creating a User
|
||||
|
||||
```http
|
||||
POST /scim/v2/Users
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"userName": "john.doe@example.com",
|
||||
"active": true,
|
||||
"emails": [
|
||||
{
|
||||
"value": "john.doe@example.com",
|
||||
"primary": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Adding a User to Groups
|
||||
|
||||
```http
|
||||
PUT /scim/v2/Users/{user_id}
|
||||
{
|
||||
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
"userName": "john.doe@example.com",
|
||||
"active": true,
|
||||
"emails": [
|
||||
{
|
||||
"value": "john.doe@example.com",
|
||||
"primary": true
|
||||
}
|
||||
],
|
||||
"groups": [
|
||||
{
|
||||
"value": "team-123",
|
||||
"display": "Engineering Team"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,154 @@
|
||||
from typing import List, Union
|
||||
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
Member,
|
||||
NewUserResponse,
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.scim_v2 import *
|
||||
|
||||
|
||||
class ScimTransformations:
    """Converters from LiteLLM user/team records to SCIM v2 resources."""

    # Fallback strings used when a record has no usable value for a SCIM
    # field that must not be empty.
    DEFAULT_SCIM_NAME = "Unknown User"
    DEFAULT_SCIM_FAMILY_NAME = "Unknown Family Name"
    DEFAULT_SCIM_DISPLAY_NAME = "Unknown Display Name"
    DEFAULT_SCIM_MEMBER_VALUE = "Unknown Member Value"

    @staticmethod
    async def transform_litellm_user_to_scim_user(
        user: Union[LiteLLM_UserTable, NewUserResponse],
    ) -> SCIMUser:
        """
        Build a SCIMUser from a LiteLLM user record.

        Each id in `user.teams` is looked up in the team table; teams that no
        longer exist are skipped. `userName` and `displayName` carry the same
        value, and the user is always reported as active.

        Raises:
            HTTPException(500): no database client is connected.
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500, detail={"error": "No database connected"}
            )

        # Get user's teams/groups
        groups = []
        for team_id in user.teams or []:
            team = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": team_id}
            )
            if team:
                # team_alias may exist but be None; fall back to the team id.
                team_alias = getattr(team, "team_alias", None) or team.team_id
                groups.append(SCIMUserGroup(value=team.team_id, display=team_alias))

        user_created_at = user.created_at.isoformat() if user.created_at else None
        user_updated_at = user.updated_at.isoformat() if user.updated_at else None

        emails = []
        if user.user_email:
            emails.append(SCIMUserEmail(value=user.user_email, primary=True))

        scim_user_name = ScimTransformations._get_scim_user_name(user)

        return SCIMUser(
            schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
            id=user.user_id,
            userName=scim_user_name,
            displayName=scim_user_name,
            name=SCIMUserName(
                familyName=ScimTransformations._get_scim_family_name(user),
                givenName=ScimTransformations._get_scim_given_name(user),
            ),
            emails=emails,
            groups=groups,
            active=True,
            meta={
                "resourceType": "User",
                "created": user_created_at,
                "lastModified": user_updated_at,
            },
        )

    @staticmethod
    def _get_scim_metadata(
        user: Union[LiteLLM_UserTable, NewUserResponse],
    ) -> Union[LiteLLM_UserScimMetadata, None]:
        """Return the parsed `scim_metadata` entry of `user.metadata`, or None if absent/empty."""
        metadata = user.metadata or {}
        raw_scim_metadata = metadata.get("scim_metadata")
        if not raw_scim_metadata:
            return None
        return LiteLLM_UserScimMetadata(**raw_scim_metadata)

    @staticmethod
    def _get_scim_user_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a display name with length > 0

        We use the same userName and displayName for SCIM users
        """
        if user.user_email:
            return user.user_email
        return ScimTransformations.DEFAULT_SCIM_DISPLAY_NAME

    @staticmethod
    def _get_scim_family_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a family name with length > 0

        Preference order: stored SCIM metadata, then `user_alias`, then the
        class default.
        """
        scim_metadata = ScimTransformations._get_scim_metadata(user)
        if scim_metadata is not None and scim_metadata.familyName:
            return scim_metadata.familyName

        if user.user_alias:
            return user.user_alias
        return ScimTransformations.DEFAULT_SCIM_FAMILY_NAME

    @staticmethod
    def _get_scim_given_name(user: Union[LiteLLM_UserTable, NewUserResponse]) -> str:
        """
        SCIM requires a given name with length > 0

        Preference order: stored SCIM metadata, then `user_alias`, then the
        class default.
        """
        scim_metadata = ScimTransformations._get_scim_metadata(user)
        if scim_metadata is not None and scim_metadata.givenName:
            return scim_metadata.givenName

        if user.user_alias:
            return user.user_alias
        return ScimTransformations.DEFAULT_SCIM_NAME

    @staticmethod
    async def transform_litellm_team_to_scim_group(
        team: Union[LiteLLM_TeamTable, dict],
    ) -> SCIMGroup:
        """
        Build a SCIMGroup from a LiteLLM team (model instance or plain dict).

        Members are taken from `team.members_with_roles`.

        Raises:
            HTTPException(500): no database client is connected.
        """
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500, detail={"error": "No database connected"}
            )

        if isinstance(team, dict):
            team = LiteLLM_TeamTable(**team)

        # Get team members
        scim_members: List[SCIMMember] = [
            SCIMMember(
                value=ScimTransformations._get_scim_member_value(member),
                display=member.user_email,
            )
            for member in team.members_with_roles or []
        ]

        # team_alias may exist but be None; fall back to the team id.
        team_alias = getattr(team, "team_alias", None) or team.team_id
        team_created_at = team.created_at.isoformat() if team.created_at else None
        team_updated_at = team.updated_at.isoformat() if team.updated_at else None

        return SCIMGroup(
            schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
            id=team.team_id,
            displayName=team_alias,
            members=scim_members,
            meta={
                "resourceType": "Group",
                "created": team_created_at,
                "lastModified": team_updated_at,
            },
        )

    @staticmethod
    def _get_scim_member_value(member: Member) -> str:
        """
        Return the identifier reported as a SCIM group member's `value`.

        The user id is preferred because the group endpoints resolve a
        member's `value` as a `user_id`; the email and then the class default
        are fallbacks for members without an id.
        """
        if member.user_id:
            return member.user_id
        if member.user_email:
            return member.user_email
        return ScimTransformations.DEFAULT_SCIM_MEMBER_VALUE
|
||||
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
SCIM v2 Endpoints for LiteLLM Proxy using Internal User/Team Management
|
||||
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
Response,
|
||||
)
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
NewTeamRequest,
|
||||
NewUserRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.internal_user_endpoints import new_user
|
||||
from litellm.proxy.management_endpoints.scim.scim_transformations import (
|
||||
ScimTransformations,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import new_team
|
||||
from litellm.types.proxy.management_endpoints.scim_v2 import *
|
||||
|
||||
# Router for the SCIM v2 endpoints: routes are mounted under the "/scim/v2" prefix and tagged "SCIM v2".
scim_router = APIRouter(
|
||||
prefix="/scim/v2",
|
||||
tags=["SCIM v2"],
|
||||
)
|
||||
|
||||
|
||||
# Dependency to set the correct SCIM Content-Type
|
||||
async def set_scim_content_type(response: Response):
    """
    Mark the outgoing response as SCIM JSON.

    The Content-Type header is always set to "application/scim+json"; any
    value already present on the response is overwritten unconditionally.
    """
    scim_media_type = "application/scim+json"
    response.headers["Content-Type"] = scim_media_type
|
||||
|
||||
|
||||
# User Endpoints
|
||||
@scim_router.get(
    "/Users",
    response_model=SCIMListResponse,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def get_users(
    startIndex: int = Query(1, ge=1),
    count: int = Query(10, ge=1, le=100),
    filter: Optional[str] = Query(None),
):
    """
    Get a list of users according to SCIM v2 protocol

    Parameters:
    - startIndex: 1-based index of the first result (translated to `skip = startIndex - 1`).
    - count: Page size, between 1 and 100.
    - filter: Optional filter. Only `userName eq "<value>"` (matched against
      `user_id`) and `emails.value eq "<value>"` (matched against `user_email`)
      are recognised; anything else is ignored.

    Users are returned newest first (ordered by `created_at` descending).

    Raises:
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        # Parse filter if provided (basic support)
        where_conditions = {}
        if filter:
            # Very basic filter support - only handling userName eq and emails.value eq.
            # Whitespace around the value is stripped before the quotes so that
            # `eq"x"` and trailing spaces no longer break the parse.
            if "userName eq" in filter:
                user_id = filter.split("userName eq")[1].strip().strip("\"'")
                where_conditions["user_id"] = user_id
            elif "emails.value eq" in filter:
                email = filter.split("emails.value eq")[1].strip().strip("\"'")
                where_conditions["user_email"] = email

        # Get users from database
        users: List[LiteLLM_UserTable] = (
            await prisma_client.db.litellm_usertable.find_many(
                where=where_conditions,
                skip=(startIndex - 1),
                take=count,
                order={"created_at": "desc"},
            )
        )

        # Get total count for pagination
        total_count = await prisma_client.db.litellm_usertable.count(
            where=where_conditions
        )

        # Convert to SCIM format
        scim_users: List[SCIMUser] = []
        for user in users:
            scim_user = await ScimTransformations.transform_litellm_user_to_scim_user(
                user=user
            )
            scim_users.append(scim_user)

        return SCIMListResponse(
            totalResults=total_count,
            startIndex=startIndex,
            itemsPerPage=min(count, len(scim_users)),
            Resources=scim_users,
        )

    except HTTPException:
        # Keep the status code of errors raised deliberately (same as sibling endpoints).
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error retrieving users: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.get(
    "/Users/{user_id}",
    response_model=SCIMUser,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def get_user(
    user_id: str = Path(..., title="User ID"),
):
    """
    Fetch one user by `user_id` and return it as a SCIM v2 User resource.

    Raises:
    - HTTPException(404): no user row exists for `user_id`.
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        user_row = await prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )

        if user_row:
            # Found: convert the LiteLLM record into its SCIM representation.
            return await ScimTransformations.transform_litellm_user_to_scim_user(
                user_row
            )

        raise HTTPException(
            status_code=404, detail={"error": f"User not found with ID: {user_id}"}
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error retrieving user: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.post(
    "/Users",
    response_model=SCIMUser,
    status_code=201,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def create_user(
    user: SCIMUser = Body(...),
):
    """
    Create a LiteLLM user from a SCIM v2 User payload.

    - `userName` becomes the `user_id` (a random UUID is generated when it is empty).
    - The first entry of `emails` becomes the `user_email`.
    - `name.givenName` becomes the `user_alias`; given and family name are
      also stored under `metadata["scim_metadata"]`.
    - `groups[*].value` become the user's team ids.
    - No API key is auto-created for the user.

    Raises:
    - HTTPException(409): a user whose `user_id` equals `userName` already exists.
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        verbose_proxy_logger.debug("SCIM CREATE USER request: %s", user)

        # First listed email, if any.
        user_email = user.emails[0].value if user.emails else None

        # Reject the request when the requested userName is already taken.
        if user.userName:
            already_present = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": user.userName}
            )
            if already_present:
                raise HTTPException(
                    status_code=409,
                    detail={"error": f"User already exists with username: {user.userName}"},
                )

        user_id = user.userName or str(uuid.uuid4())
        given_name = user.name.givenName
        team_ids = [group.value for group in user.groups] if user.groups else None
        scim_metadata = LiteLLM_UserScimMetadata(
            givenName=user.name.givenName,
            familyName=user.name.familyName,
        ).model_dump()

        new_user_request = NewUserRequest(
            user_id=user_id,
            user_email=user_email,
            user_alias=given_name,
            teams=team_ids,
            metadata={"scim_metadata": scim_metadata},
            auto_create_key=False,
        )
        created_user = await new_user(data=new_user_request)

        return await ScimTransformations.transform_litellm_user_to_scim_user(
            user=created_user
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error creating user: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.put(
    "/Users/{user_id}",
    response_model=SCIMUser,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def update_user(
    user_id: str = Path(..., title="User ID"),
    user: SCIMUser = Body(...),
):
    """
    Update a user according to SCIM v2 protocol

    NOTE: full-replace (PUT) semantics are not implemented yet. The endpoint
    used to return None, which cannot satisfy `response_model=SCIMUser`; it now
    reports the missing feature explicitly with 501 Not Implemented.

    Raises:
    - HTTPException(500): no database connected.
    - HTTPException(501): always, until user replacement is implemented.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    raise HTTPException(
        status_code=501,
        detail={
            "error": f"Updating users via SCIM PUT is not implemented (user_id={user_id})"
        },
    )
|
||||
|
||||
|
||||
@scim_router.delete(
    "/Users/{user_id}",
    status_code=204,
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_user(
    user_id: str = Path(..., title="User ID"),
):
    """
    Delete a user and detach them from their teams (SCIM v2 DELETE).

    Every team listed on the user is loaded first; afterwards the user id is
    removed from each team's `members` list, and finally the user row itself
    is deleted. Responds with an empty 204.

    Raises:
    - HTTPException(404): no user row exists for `user_id`.
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        user_row = await prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )
        if not user_row:
            raise HTTPException(
                status_code=404, detail={"error": f"User not found with ID: {user_id}"}
            )

        # Phase 1: collect the teams that still exist for this user.
        member_teams = []
        for team_id in user_row.teams or []:
            team_row = await prisma_client.db.litellm_teamtable.find_unique(
                where={"team_id": team_id}
            )
            if team_row:
                member_teams.append(team_row)

        # Phase 2: strip the user id from each team's member list.
        for team_row in member_teams:
            roster = team_row.members or []
            if user_id not in roster:
                continue
            remaining = [member for member in roster if member != user_id]
            await prisma_client.db.litellm_teamtable.update(
                where={"team_id": team_row.team_id}, data={"members": remaining}
            )

        # Finally remove the user record itself.
        await prisma_client.db.litellm_usertable.delete(where={"user_id": user_id})

        return Response(status_code=204)

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error deleting user: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.patch(
    "/Users/{user_id}",
    response_model=SCIMUser,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def patch_user(
    user_id: str = Path(..., title="User ID"),
    patch_ops: SCIMPatchOp = Body(...),
):
    """
    Patch a user according to SCIM v2 protocol

    NOTE: applying patch operations is not implemented yet. After the existence
    check the endpoint used to return None, which cannot satisfy
    `response_model=SCIMUser`; it now reports 501 Not Implemented instead.

    Raises:
    - HTTPException(404): no user row exists for `user_id`.
    - HTTPException(500): no database connected, or the lookup failed.
    - HTTPException(501): the user exists but PATCH is not implemented.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    verbose_proxy_logger.debug("SCIM PATCH USER request: %s", patch_ops)

    try:
        # Check if user exists
        existing_user = await prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )

        if not existing_user:
            raise HTTPException(
                status_code=404, detail={"error": f"User not found with ID: {user_id}"}
            )

        raise HTTPException(
            status_code=501,
            detail={
                "error": f"Patching users via SCIM PATCH is not implemented (user_id={user_id})"
            },
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error patching user: {str(e)}"}
        )
|
||||
|
||||
|
||||
# Group Endpoints
|
||||
@scim_router.get(
    "/Groups",
    response_model=SCIMListResponse,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def get_groups(
    startIndex: int = Query(1, ge=1),
    count: int = Query(10, ge=1, le=100),
    filter: Optional[str] = Query(None),
):
    """
    Get a list of groups according to SCIM v2 protocol

    Parameters:
    - startIndex: 1-based index of the first result (translated to `skip = startIndex - 1`).
    - count: Page size, between 1 and 100.
    - filter: Optional filter. Only `displayName eq "<value>"` (matched against
      `team_alias`) is recognised; anything else is ignored.

    Teams are returned newest first (ordered by `created_at` descending). Group
    members are built from each team's `members` id list; ids without a
    matching user row are omitted.

    Raises:
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        # Parse filter if provided (basic support)
        where_conditions = {}
        if filter:
            # Very basic filter support - only handling displayName eq.
            # Whitespace around the value is stripped before the quotes so that
            # `eq"x"` and trailing spaces no longer break the parse.
            if "displayName eq" in filter:
                team_alias = filter.split("displayName eq")[1].strip().strip("\"'")
                where_conditions["team_alias"] = team_alias

        # Get teams from database
        teams = await prisma_client.db.litellm_teamtable.find_many(
            where=where_conditions,
            skip=(startIndex - 1),
            take=count,
            order={"created_at": "desc"},
        )

        # Get total count for pagination
        total_count = await prisma_client.db.litellm_teamtable.count(
            where=where_conditions
        )

        # Convert to SCIM format
        scim_groups = []
        for team in teams:
            # Get team members
            members = []
            for member_id in team.members or []:
                member = await prisma_client.db.litellm_usertable.find_unique(
                    where={"user_id": member_id}
                )
                if member:
                    display_name = member.user_email or member.user_id
                    members.append(
                        SCIMMember(value=member.user_id, display=display_name)
                    )

            # team_alias may exist but be None; fall back to the team id
            # (same fallback update_group applies to its response).
            team_alias = getattr(team, "team_alias", None) or team.team_id
            team_created_at = team.created_at.isoformat() if team.created_at else None
            team_updated_at = team.updated_at.isoformat() if team.updated_at else None

            scim_group = SCIMGroup(
                schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
                id=team.team_id,
                displayName=team_alias,
                members=members,
                meta={
                    "resourceType": "Group",
                    "created": team_created_at,
                    "lastModified": team_updated_at,
                },
            )
            scim_groups.append(scim_group)

        return SCIMListResponse(
            totalResults=total_count,
            startIndex=startIndex,
            itemsPerPage=min(count, len(scim_groups)),
            Resources=scim_groups,
        )

    except HTTPException:
        # Keep the status code of errors raised deliberately (same as sibling endpoints).
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error retrieving groups: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.get(
    "/Groups/{group_id}",
    response_model=SCIMGroup,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def get_group(
    group_id: str = Path(..., title="Group ID"),
):
    """
    Fetch one team by `team_id` and return it as a SCIM v2 Group resource.

    Raises:
    - HTTPException(404): no team row exists for `group_id`.
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        team_row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": group_id}
        )

        if team_row:
            # Found: convert the LiteLLM team into its SCIM representation.
            return await ScimTransformations.transform_litellm_team_to_scim_group(
                team_row
            )

        raise HTTPException(
            status_code=404,
            detail={"error": f"Group not found with ID: {group_id}"},
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error retrieving group: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.post(
    "/Groups",
    response_model=SCIMGroup,
    status_code=201,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def create_group(
    group: SCIMGroup = Body(...),
):
    """
    Create a LiteLLM team from a SCIM v2 Group payload.

    - `id` becomes the `team_id` (a random UUID is generated when it is empty).
    - `displayName` becomes the `team_alias`.
    - Each member whose `value` matches an existing `user_id` is added with
      role "user"; unknown members are dropped.

    The team is created through `new_team`, invoked with a synthetic request
    and a proxy-admin identity.

    Raises:
    - HTTPException(409): a team with the resolved `team_id` already exists.
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        team_id = group.id or str(uuid.uuid4())

        # Refuse to overwrite an existing team.
        duplicate_team = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )
        if duplicate_team:
            raise HTTPException(
                status_code=409,
                detail={"error": f"Group already exists with ID: {team_id}"},
            )

        # Keep only the members that correspond to a known user.
        members_with_roles: List[Member] = []
        for scim_member in group.members or []:
            known_user = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": scim_member.value}
            )
            if known_user:
                members_with_roles.append(
                    Member(user_id=scim_member.value, role="user")
                )

        team_request = NewTeamRequest(
            team_id=team_id,
            team_alias=group.displayName,
            members_with_roles=members_with_roles,
        )
        created_team = await new_team(
            data=team_request,
            http_request=Request(scope={"type": "http", "path": "/scim/v2/Groups"}),
            user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
        )

        return await ScimTransformations.transform_litellm_team_to_scim_group(
            created_team
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error creating group: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.put(
    "/Groups/{group_id}",
    response_model=SCIMGroup,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def update_group(
    group_id: str = Path(..., title="Group ID"),
    group: SCIMGroup = Body(...),
):
    """
    Update a group according to SCIM v2 protocol

    Steps:
    - Replace the team's `team_alias` and `members`, and store the raw SCIM
      payload under `metadata["scim_data"]`.
    - Push `group_id` onto the `teams` list of every newly added member.
    - Remove `group_id` from the `teams` list of every member no longer listed.

    Members whose `value` does not match an existing `user_id` are ignored, and
    duplicate member ids in the payload are collapsed. Each requested user is
    looked up once and the row is reused for the remaining steps.

    Raises:
    - HTTPException(404): no team row exists for `group_id`.
    - HTTPException(500): no database connected, or any unexpected failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        # Check if team exists
        existing_team = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": group_id}
        )

        if not existing_team:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Group not found with ID: {group_id}"},
            )

        # Extract members: user_id -> user row, in payload order, existing users only.
        requested_users = {}
        for member in group.members or []:
            if member.value in requested_users:
                continue
            user = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": member.value}
            )
            if user:
                requested_users[member.value] = user
        member_ids = list(requested_users)

        # Update team in database
        existing_metadata = existing_team.metadata if existing_team.metadata else {}
        updated_team = await prisma_client.db.litellm_teamtable.update(
            where={"team_id": group_id},
            data={
                "team_alias": group.displayName,
                "members": member_ids,
                "metadata": {**existing_metadata, "scim_data": group.model_dump()},
            },
        )

        # Handle user-team relationships
        current_members = existing_team.members or []
        current_member_set = set(current_members)

        # Add new members to team
        for member_id, user in requested_users.items():
            if member_id in current_member_set:
                continue
            if group_id not in (user.teams or []):
                await prisma_client.db.litellm_usertable.update(
                    where={"user_id": member_id},
                    data={"teams": {"push": group_id}},
                )

        # Remove former members from team
        for member_id in current_members:
            if member_id in requested_users:
                continue
            user = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": member_id}
            )
            if user:
                current_user_teams = user.teams or []
                if group_id in current_user_teams:
                    new_teams = [t for t in current_user_teams if t != group_id]
                    await prisma_client.db.litellm_usertable.update(
                        where={"user_id": member_id}, data={"teams": new_teams}
                    )

        # Build the response members from the rows fetched during validation.
        members = [
            SCIMMember(value=user.user_id, display=user.user_email or user.user_id)
            for user in requested_users.values()
        ]

        team_created_at = (
            updated_team.created_at.isoformat() if updated_team.created_at else None
        )
        team_updated_at = (
            updated_team.updated_at.isoformat() if updated_team.updated_at else None
        )

        return SCIMGroup(
            schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
            id=group_id,
            displayName=updated_team.team_alias or group_id,
            members=members,
            meta={
                "resourceType": "Group",
                "created": team_created_at,
                "lastModified": team_updated_at,
            },
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error updating group: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.delete(
    "/Groups/{group_id}",
    status_code=204,
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_group(
    group_id: str = Path(..., title="Group ID"),
):
    """
    Delete a SCIM group (stored as a team row) and detach it from its members.

    Parameters:
    - group_id: str - The team_id of the group to delete

    Returns:
    - An empty response with status 204

    Raises:
    - HTTPException 404 if no team row exists for group_id
    - HTTPException 500 if no database is connected or any other error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    try:
        team_row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": group_id}
        )
        if not team_row:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Group not found with ID: {group_id}"},
            )

        # Strip this team from each member's `teams` list before the team row
        # itself is removed.
        for member_id in team_row.members or []:
            member = await prisma_client.db.litellm_usertable.find_unique(
                where={"user_id": member_id}
            )
            if not member:
                continue

            memberships = member.teams or []
            if group_id not in memberships:
                continue

            remaining = [team for team in memberships if team != group_id]
            await prisma_client.db.litellm_usertable.update(
                where={"user_id": member_id}, data={"teams": remaining}
            )

        await prisma_client.db.litellm_teamtable.delete(where={"team_id": group_id})
        return Response(status_code=204)

    except HTTPException:
        # Keep deliberate status codes (e.g. the 404 above) intact.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error deleting group: {str(e)}"}
        )
|
||||
|
||||
|
||||
@scim_router.patch(
    "/Groups/{group_id}",
    response_model=SCIMGroup,
    status_code=200,
    dependencies=[Depends(user_api_key_auth), Depends(set_scim_content_type)],
)
async def patch_group(
    group_id: str = Path(..., title="Group ID"),
    patch_ops: SCIMPatchOp = Body(...),
):
    """
    Handle a SCIM v2 PATCH request for a group.

    Parameters:
    - group_id: str - The team_id of the group to patch
    - patch_ops: SCIMPatchOp - The requested patch operations (only logged here)

    Returns:
    - None once the group is confirmed to exist

    Raises:
    - HTTPException 404 if no team row exists for group_id
    - HTTPException 500 if no database is connected or any other error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No database connected"})

    verbose_proxy_logger.debug("SCIM PATCH GROUP request: %s", patch_ops)

    try:
        team_row = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": group_id}
        )
        if team_row:
            # NOTE(review): `patch_ops` is never applied and None is returned
            # even though the route declares response_model=SCIMGroup - this
            # looks like a placeholder implementation; confirm before relying
            # on it.
            return None

        raise HTTPException(
            status_code=404,
            detail={"error": f"Group not found with ID: {group_id}"},
        )
    except HTTPException:
        # Keep the 404 above intact.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail={"error": f"Error patching group: {str(e)}"}
        )
|
||||
@@ -0,0 +1,22 @@
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
|
||||
def check_is_admin_only_access(ui_access_mode: str) -> bool:
    """Return True only when the UI access mode is exactly "admin_only"."""
    is_admin_only = ui_access_mode == "admin_only"
    return is_admin_only
|
||||
|
||||
|
||||
def has_admin_ui_access(user_role: str) -> bool:
    """
    Report whether a role grants admin access to the UI.

    Returns:
        bool: True if user_role equals the value of LitellmUserRoles.PROXY_ADMIN
        or LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, False otherwise.
    """
    admin_roles = (
        LitellmUserRoles.PROXY_ADMIN.value,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
    )
    return user_role in admin_roles
|
||||
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
TAG MANAGEMENT
|
||||
|
||||
All /tag management endpoints
|
||||
|
||||
/tag/new
|
||||
/tag/info
|
||||
/tag/update
|
||||
/tag/delete
|
||||
/tag/list
/tag/daily/activity
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
get_daily_activity,
|
||||
)
|
||||
from litellm.types.tag_management import (
|
||||
LiteLLM_DailyTagSpendTable,
|
||||
TagConfig,
|
||||
TagDeleteRequest,
|
||||
TagInfoRequest,
|
||||
TagNewRequest,
|
||||
TagUpdateRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_model_names(prisma_client, model_ids: list) -> Dict[str, str]:
|
||||
"""Helper function to get model names from model IDs"""
|
||||
try:
|
||||
models = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
where={"model_id": {"in": model_ids}}
|
||||
)
|
||||
return {model.model_id: model.model_name for model in models}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error getting model names: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
async def _get_tags_config(prisma_client) -> Dict[str, TagConfig]:
    """
    Helper function to get tags config from db.

    Reads the `tags_config` row of `litellm_config`, decodes it from JSON when
    it is stored as a string, and attaches a `model_info` mapping
    (model_id -> model_name) to every dict entry that lists models.

    Returns:
    - The decoded config keyed by tag name. An empty dict is returned when the
      row does not exist or when reading/decoding fails (the failure is
      logged, never raised).

    NOTE(review): the values come straight from the stored JSON, so they
    appear to be plain dicts rather than TagConfig instances despite the
    annotation - confirm against callers.
    """
    try:
        tags_config = await prisma_client.db.litellm_config.find_unique(
            where={"param_name": "tags_config"}
        )
        if tags_config is None:
            return {}

        # Convert from JSON if needed
        if isinstance(tags_config.param_value, str):
            config_dict = json.loads(tags_config.param_value)
        else:
            config_dict = tags_config.param_value or {}

        # For each tag, get the model names
        for tag_config in config_dict.values():
            if isinstance(tag_config, dict) and tag_config.get("models"):
                model_info = await _get_model_names(prisma_client, tag_config["models"])
                tag_config["model_info"] = model_info

        return config_dict
    except Exception as e:
        # Stay best-effort for callers, but leave a trace: several callers
        # write back whatever this returns, so a silent {} hides real trouble.
        verbose_proxy_logger.exception(f"Error getting tags config: {str(e)}")
        return {}
|
||||
|
||||
|
||||
async def _save_tags_config(prisma_client, tags_config: Dict[str, TagConfig]):
    """
    Helper function to save tags config to db.

    Each entry is converted to a plain dict, its `model_info` key is dropped
    (it is regenerated when the config is read), and the result is upserted
    as JSON into the `tags_config` row of `litellm_config`.

    Raises:
    - HTTPException 500 if serialising or writing fails
    """
    try:
        verbose_proxy_logger.debug(f"Saving tags config: {tags_config}")

        serializable = {}
        for name, tag in tags_config.items():
            # TagConfig objects are dumped; anything else is treated as a dict
            # and shallow-copied so the caller's object is left untouched.
            entry = tag.model_dump() if isinstance(tag, TagConfig) else tag.copy()
            entry.pop("model_info", None)
            serializable[name] = entry

        json_tags_config = json.dumps(serializable, default=str)
        verbose_proxy_logger.debug(f"JSON tags config: {json_tags_config}")

        create_payload = {
            "param_name": "tags_config",
            "param_value": json_tags_config,
        }
        update_payload = {"param_value": json_tags_config}
        await prisma_client.db.litellm_config.upsert(
            where={"param_name": "tags_config"},
            data={"create": create_payload, "update": update_payload},
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error saving tags config: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post(
    "/tag/new",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_tag(
    tag: TagNewRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new tag.

    Parameters:
    - name: str - The name of the tag
    - description: Optional[str] - Description of what this tag represents
    - models: List[str] - List of LLM models allowed for this tag

    Returns:
    - A dict with a confirmation `message` and the created `tag`

    Raises:
    - HTTPException 400 if a tag with this name already exists
    - HTTPException 404 if one of the listed models has no deployment
    - HTTPException 500 if the database is not connected or any other error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")
    try:
        # Get existing tags config
        tags_config = await _get_tags_config(prisma_client)

        # Check if tag already exists
        if tag.name in tags_config:
            raise HTTPException(
                status_code=400, detail=f"Tag {tag.name} already exists"
            )

        # Add new tag; a single timestamp keeps created_at == updated_at
        now = str(datetime.datetime.now())
        tags_config[tag.name] = TagConfig(
            name=tag.name,
            description=tag.description,
            models=tag.models,
            created_at=now,
            updated_at=now,
            created_by=user_api_key_dict.user_id,
        )

        # Save updated config
        await _save_tags_config(
            prisma_client=prisma_client,
            tags_config=tags_config,
        )

        # Update models with new tag
        if tag.models:
            for model_id in tag.models:
                await _add_tag_to_deployment(
                    model_id=model_id,
                    tag=tag.name,
                )

        # Get model names for response
        model_info = await _get_model_names(prisma_client, tag.models or [])
        tags_config[tag.name].model_info = model_info

        return {
            "message": f"Tag {tag.name} created successfully",
            "tag": tags_config[tag.name],
        }
    except HTTPException:
        # Deliberate client errors (400 / 404) must not be rewritten as 500.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _add_tag_to_deployment(model_id: str, tag: str):
    """
    Helper function to add tag to deployment.

    Loads the deployment row for `model_id`, adds `tag` to the `tags` list in
    its `litellm_params`, and writes the params back. If the tag is already
    present nothing is written, so repeated calls do not create duplicates.

    Raises:
    - HTTPException 500 if the database is not connected
    - HTTPException 404 if no deployment exists for model_id
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
        where={"model_id": model_id}
    )
    if deployment is None:
        raise HTTPException(status_code=404, detail=f"Deployment {model_id} not found")

    # assumes litellm_params is a dict when present - TODO confirm
    litellm_params = deployment.litellm_params or {}
    # Treat a missing or null "tags" entry as an empty list.
    existing_tags = litellm_params.get("tags") or []
    if tag in existing_tags:
        return

    litellm_params["tags"] = [*existing_tags, tag]
    await prisma_client.db.litellm_proxymodeltable.update(
        where={"model_id": model_id},
        data={"litellm_params": safe_dumps(litellm_params)},
    )
|
||||
|
||||
|
||||
@router.post(
    "/tag/update",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_tag(
    tag: TagUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing tag.

    Parameters:
    - name: str - The name of the tag to update
    - description: Optional[str] - Updated description
    - models: List[str] - Updated list of allowed LLM models

    Returns:
    - A dict with a confirmation `message` and the updated `tag`

    Raises:
    - HTTPException 404 if the tag does not exist
    - HTTPException 500 if the database is not connected or any other error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Get existing tags config
        tags_config = await _get_tags_config(prisma_client)

        # Check if tag exists
        if tag.name not in tags_config:
            raise HTTPException(status_code=404, detail=f"Tag {tag.name} not found")

        # Update tag: description and models are overwritten with the request
        # values, every other stored field is carried over.
        tag_config_dict = dict(tags_config[tag.name])
        tag_config_dict.update(
            {
                "description": tag.description,
                "models": tag.models,
                "updated_at": str(datetime.datetime.now()),
                "updated_by": user_api_key_dict.user_id,
            }
        )
        tags_config[tag.name] = TagConfig(**tag_config_dict)

        # Save updated config
        await _save_tags_config(prisma_client, tags_config)

        # Get model names for response
        model_info = await _get_model_names(prisma_client, tag.models or [])
        tags_config[tag.name].model_info = model_info

        return {
            "message": f"Tag {tag.name} updated successfully",
            "tag": tags_config[tag.name],
        }
    except HTTPException:
        # The 404 above must reach the client as a 404, not as a 500.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating tag: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/tag/info",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def info_tag(
    data: TagInfoRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get information about specific tags.

    Parameters:
    - names: List[str] - List of tag names to get information for

    Returns:
    - A dict mapping each requested name to its stored tag config

    Raises:
    - HTTPException 404 if any requested tag does not exist
    - HTTPException 500 if the database is not connected or any other error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        tags_config = await _get_tags_config(prisma_client)

        # Check if any requested tags don't exist
        missing_tags = [name for name in data.names if name not in tags_config]
        if missing_tags:
            raise HTTPException(
                status_code=404, detail=f"Tags not found: {missing_tags}"
            )

        # Filter tags based on requested names
        return {name: tags_config[name] for name in data.names}
    except HTTPException:
        # The 404 above must reach the client as a 404, not as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/tag/list",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[TagConfig],
)
async def list_tags(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    List all available tags.

    The result is the stored tag configs followed by one generated entry for
    every distinct tag in the daily tag spend table that has no stored config.

    Raises:
    - HTTPException 500 if the database is not connected or any error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        ## QUERY STORED TAGS ##
        stored_tags = await _get_tags_config(prisma_client)

        ## QUERY DYNAMIC TAGS ##
        spend_rows = await prisma_client.db.litellm_dailytagspend.find_many(
            distinct=["tag"],
        )
        spend_tags = []
        for spend_row in spend_rows:
            spend_tags.append(LiteLLM_DailyTagSpendTable(**spend_row.model_dump()))

        all_tags = list(stored_tags.values())
        for spend_tag in spend_tags:
            # Stored configs win over spend-only entries of the same name.
            if spend_tag.tag in stored_tags:
                continue
            all_tags.append(
                TagConfig(
                    name=spend_tag.tag,
                    description="This is just a spend tag that was passed dynamically in a request. It does not control any LLM models.",
                    models=None,
                    created_at=spend_tag.created_at.isoformat(),
                    updated_at=spend_tag.updated_at.isoformat(),
                )
            )

        return all_tags
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
    "/tag/delete",
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_tag(
    data: TagDeleteRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete a tag.

    Parameters:
    - name: str - The name of the tag to delete

    Returns:
    - A dict with a confirmation `message`

    Raises:
    - HTTPException 404 if the tag does not exist
    - HTTPException 500 if the database is not connected or any other error occurs
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Get existing tags config
        tags_config = await _get_tags_config(prisma_client)

        # Check if tag exists
        if data.name not in tags_config:
            raise HTTPException(status_code=404, detail=f"Tag {data.name} not found")

        # Delete tag
        # NOTE(review): only the stored config entry is removed here; the tag
        # is not stripped from any deployment's litellm_params - confirm that
        # is intended.
        del tags_config[data.name]

        # Save updated config
        await _save_tags_config(prisma_client, tags_config)

        return {"message": f"Tag {data.name} deleted successfully"}
    except HTTPException:
        # The 404 above must reach the client as a 404, not as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
    "/tag/daily/activity",
    response_model=SpendAnalyticsPaginatedResponse,
    tags=["tag management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_tag_daily_activity(
    tags: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
):
    """
    Return paginated daily activity for the given tags, or for all tags.

    Args:
        tags (Optional[str]): Comma-separated list of tags to filter by. If not provided, no tag filter is passed on.
        start_date (Optional[str]): Start date for the activity period (YYYY-MM-DD).
        end_date (Optional[str]): End date for the activity period (YYYY-MM-DD).
        model (Optional[str]): Filter by model name.
        api_key (Optional[str]): Filter by API key.
        page (int): Page number for pagination.
        page_size (int): Number of items per page.

    Returns:
        SpendAnalyticsPaginatedResponse: Whatever `get_daily_activity` returns for the `litellm_dailytagspend` table.
    """
    from litellm.proxy.proxy_server import prisma_client

    # Split on commas only; an empty or missing value means "no tag filter".
    tag_filter = None
    if tags:
        tag_filter = tags.split(",")

    activity = await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailytagspend",
        entity_id_field="tag",
        entity_id=tag_filter,
        entity_metadata_field=None,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
    )
    return activity
|
||||
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Endpoints to control callbacks per team
|
||||
|
||||
Use this when each team should control its own callbacks
|
||||
"""
|
||||
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
AddTeamCallback,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
TeamCallbackMetadata,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/team/{team_id:path}/callback",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def add_team_callbacks(
    data: AddTeamCallback,
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Add a success/failure callback to a team

    Use this if you want different teams to have different success/failure callbacks

    Parameters:
    - callback_name (Literal["langfuse", "langsmith", "gcs"], required): The name of the callback to add
    - callback_type (Literal["success", "failure", "success_and_failure"], required): The type of callback to add. One of:
        - "success": Callback for successful LLM calls
        - "failure": Callback for failed LLM calls
        - "success_and_failure": Callback for both successful and failed LLM calls
    - callback_vars (StandardCallbackDynamicParams, required): A dictionary of variables to pass to the callback
        - langfuse_public_key: The public key for the Langfuse callback
        - langfuse_secret_key: The secret key for the Langfuse callback
        - langfuse_secret: The secret for the Langfuse callback
        - langfuse_host: The host for the Langfuse callback
        - gcs_bucket_name: The name of the GCS bucket
        - gcs_path_service_account: The path to the GCS service account
        - langsmith_api_key: The API key for the Langsmith callback
        - langsmith_project: The project for the Langsmith callback
        - langsmith_base_url: The base URL for the Langsmith callback

    Example curl:
    ```
    curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \
        -H 'Content-Type: application/json' \
        -H 'Authorization: Bearer sk-1234' \
        -d '{
        "callback_name": "langfuse",
        "callback_type": "success",
        "callback_vars": {"langfuse_public_key": "pk-lf-xxxx1", "langfuse_secret_key": "sk-xxxxx"}

    }'
    ```

    This means for the team where team_id = dbe2f686-a686-4896-864a-4c3924458709, all LLM calls will be logged to langfuse using the public key pk-lf-xxxx1 and the secret key sk-xxxxx

    Returns:
    - A dict with `status` "success" and the updated team row under `data`

    Raises:
    - ProxyException (400) if the callback is already registered in a targeted list
    - ProxyException wrapping any HTTPException raised here (no db connected, unknown team)
    - ProxyException (500) for any other error
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})

        # Check if team_id exists already
        _existing_team = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team is None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Team id = {team_id} does not exist. Please use a different team id."
                },
            )

        # store team callback settings in metadata
        team_metadata = _existing_team.metadata
        team_callback_settings = team_metadata.get("callback_settings", {})
        # expect callback settings to be
        team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings)

        # Which callback lists does this request touch?
        if data.callback_type == "success":
            target_fields = ["success_callback"]
        elif data.callback_type == "failure":
            target_fields = ["failure_callback"]
        elif data.callback_type == "success_and_failure":
            target_fields = ["success_callback", "failure_callback"]
        else:
            target_fields = []

        # Validate every targeted list before mutating any of them, so a
        # duplicate is rejected without a partial update.
        for field_name in target_fields:
            if getattr(team_callback_settings_obj, field_name) is None:
                setattr(team_callback_settings_obj, field_name, [])
            registered = getattr(team_callback_settings_obj, field_name)
            if data.callback_name in registered:
                raise ProxyException(
                    message=f"callback_name = {data.callback_name} already exists in {field_name}, for team_id = {team_id}. \n Existing {field_name} = {registered}",
                    code=status.HTTP_400_BAD_REQUEST,
                    type=ProxyErrorTypes.bad_request_error,
                    param="callback_name",
                )

        for field_name in target_fields:
            getattr(team_callback_settings_obj, field_name).append(data.callback_name)

        # Merge the supplied callback variables over any stored ones
        for var, value in data.callback_vars.items():
            if team_callback_settings_obj.callback_vars is None:
                team_callback_settings_obj.callback_vars = {}
            team_callback_settings_obj.callback_vars[var] = value

        team_callback_settings_obj_dict = team_callback_settings_obj.model_dump()

        team_metadata["callback_settings"] = team_callback_settings_obj_dict
        team_metadata_json = json.dumps(team_metadata)  # update team_metadata

        new_team_row = await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id}, data={"metadata": team_metadata_json}  # type: ignore
        )

        return {
            "status": "success",
            "data": new_team_row,
        }

    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.add_team_callbacks(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            # Re-wrap so the client receives the proxy's error envelope while
            # keeping the original status code.
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error.value,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.post(
    "/team/{team_id}/disable_logging",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def disable_team_logging(
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Clear the success and failure callback lists stored for a team

    Parameters:
    - team_id (str, required): The unique identifier for the team

    Example curl:
    ```
    curl -X POST 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/disable_logging' \
        -H 'Authorization: Bearer sk-1234'
    ```

    Returns a dict with `status`, a `message`, and `data` holding the team_id
    together with empty `success_callbacks` / `failure_callbacks` lists.
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})

        # Look the team up first so an unknown id yields a 404
        team_row = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if team_row is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Team id = {team_id} does not exist."},
            )

        # Rebuild the callback settings with both callback lists emptied;
        # any other stored setting (e.g. callback_vars) is carried over.
        metadata = team_row.metadata
        callback_settings = TeamCallbackMetadata(
            **metadata.get("callback_settings", {})
        )
        callback_settings.success_callback = []
        callback_settings.failure_callback = []
        metadata["callback_settings"] = callback_settings.model_dump()

        # Persist the modified metadata as JSON
        serialized_metadata = json.dumps(metadata)
        updated_team = await prisma_client.db.litellm_teamtable.update(
            where={"team_id": team_id}, data={"metadata": serialized_metadata}  # type: ignore
        )

        if updated_team is None:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": f"Team id = {team_id} does not exist. Error updating team logging"
                },
            )

        response_data = {
            "team_id": updated_team.team_id,
            "success_callbacks": [],
            "failure_callbacks": [],
        }
        return {
            "status": "success",
            "message": f"Logging disabled for team {team_id}",
            "data": response_data,
        }

    except Exception as e:
        verbose_proxy_logger.error(
            f"litellm.proxy.proxy_server.disable_team_logging(): Exception occurred - {str(e)}"
        )
        verbose_proxy_logger.debug(traceback.format_exc())

        if isinstance(e, HTTPException):
            # Re-wrap in the proxy's error envelope, keeping the status code
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error.value,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        if isinstance(e, ProxyException):
            raise e

        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.get(
    "/team/{team_id:path}/callback",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def get_team_callbacks(
    http_request: Request,
    team_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get the success/failure callbacks and variables for a team

    Parameters:
    - team_id (str, required): The unique identifier for the team

    Example curl:
    ```
    curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/callback' \
        -H 'Authorization: Bearer sk-1234'
    ```

    This will return the callback settings for the team with id dbe2f686-a686-4896-864a-4c3924458709

    Returns {
            "status": "success",
            "data": {
                "team_id": team_id,
                "success_callbacks": team_callback_settings_obj.success_callback,
                "failure_callbacks": team_callback_settings_obj.failure_callback,
                "callback_vars": team_callback_settings_obj.callback_vars,
            },
        }

    A team with no metadata, or with no (or a null) "callback_settings" entry,
    is reported with the default (empty) TeamCallbackMetadata values.

    Raises:
    - ProxyException: code 500 when no database is connected, code 404 when the
      team does not exist, code 500 for any other unexpected failure.
    """
    try:
        # Imported inside the handler, as the original block does, so the
        # current value of `prisma_client` is read on every request.
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(status_code=500, detail={"error": "No db connected"})

        # Check if team_id exists
        _existing_team = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Team id = {team_id} does not exist."},
            )

        # Retrieve team callback settings from metadata.
        # Both the metadata itself and the "callback_settings" entry may be
        # None; fall back to an empty dict so a team without callbacks yields
        # an empty result instead of an AttributeError/TypeError (-> HTTP 500).
        team_metadata = _existing_team.metadata or {}
        team_callback_settings = team_metadata.get("callback_settings") or {}

        # Convert to TeamCallbackMetadata object for consistent structure
        team_callback_settings_obj = TeamCallbackMetadata(**team_callback_settings)

        return {
            "status": "success",
            "data": {
                "team_id": team_id,
                "success_callbacks": team_callback_settings_obj.success_callback,
                "failure_callbacks": team_callback_settings_obj.failure_callback,
                "callback_vars": team_callback_settings_obj.callback_vars,
            },
        }

    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.get_team_callbacks(): Exception occurred - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            # Preserve the HTTP status code and detail of deliberate errors
            # (e.g. the 404 above) while normalising to ProxyException.
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type=ProxyErrorTypes.internal_server_error.value,
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        # Anything else is an unexpected failure: report it as a 500.
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type=ProxyErrorTypes.internal_server_error.value,
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Types for the management endpoints
|
||||
|
||||
May import packages that are only listed in the fastapi/proxy requirements.txt
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi_sso.sso.base import OpenID
|
||||
|
||||
|
||||
# Extension of fastapi_sso's `OpenID` that additionally declares the ids of
# the teams associated with the identity.
# NOTE(review): presumably populated by a custom SSO handler — confirm against callers.
class CustomOpenID(OpenID):
    # Team ids attached to this identity; declared without a default value.
    team_ids: List[str]
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user