new mcp servers format
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -202,6 +202,7 @@ async def new_user(
|
||||
- team_id: Optional[str] - [DEPRECATED PARAM] The team id of the user. Default is None.
|
||||
- duration: Optional[str] - Duration for the key auto-created on `/user/new`. Default is None.
|
||||
- key_alias: Optional[str] - Alias for the key auto-created on `/user/new`. Default is None.
|
||||
- sso_user_id: Optional[str] - The id of the user in the SSO provider.
|
||||
|
||||
Returns:
|
||||
- key: (str) The generated api key for the user
|
||||
|
||||
@@ -1173,6 +1173,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
||||
created_by: Optional[str] = None,
|
||||
updated_by: Optional[str] = None,
|
||||
allowed_routes: Optional[list] = None,
|
||||
sso_user_id: Optional[str] = None,
|
||||
):
|
||||
from litellm.proxy.proxy_server import (
|
||||
litellm_proxy_budget_name,
|
||||
@@ -1251,6 +1252,7 @@ async def generate_key_helper_fn( # noqa: PLR0915
|
||||
"budget_duration": budget_duration,
|
||||
"budget_reset_at": reset_at,
|
||||
"allowed_cache_controls": allowed_cache_controls,
|
||||
"sso_user_id": sso_user_id,
|
||||
}
|
||||
if teams is not None:
|
||||
user_data["teams"] = teams
|
||||
@@ -1859,6 +1861,7 @@ async def validate_key_list_check(
|
||||
team_id: Optional[str],
|
||||
organization_id: Optional[str],
|
||||
key_alias: Optional[str],
|
||||
key_hash: Optional[str],
|
||||
prisma_client: PrismaClient,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||
@@ -1922,6 +1925,31 @@ async def validate_key_list_check(
|
||||
param="organization_id",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
if key_hash:
|
||||
try:
|
||||
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": key_hash},
|
||||
)
|
||||
except Exception:
|
||||
raise ProxyException(
|
||||
message="Key Hash not found.",
|
||||
type=ProxyErrorTypes.bad_request_error,
|
||||
param="key_hash",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
can_user_query_key_info = await _can_user_query_key_info(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
key=key_hash,
|
||||
key_info=key_info,
|
||||
)
|
||||
if not can_user_query_key_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You are not allowed to access this key's info. Your role={}".format(
|
||||
user_api_key_dict.user_role
|
||||
),
|
||||
)
|
||||
return complete_user_info
|
||||
|
||||
|
||||
@@ -1970,6 +1998,7 @@ async def list_keys(
|
||||
organization_id: Optional[str] = Query(
|
||||
None, description="Filter keys by organization ID"
|
||||
),
|
||||
key_hash: Optional[str] = Query(None, description="Filter keys by key hash"),
|
||||
key_alias: Optional[str] = Query(None, description="Filter keys by key alias"),
|
||||
return_full_object: bool = Query(False, description="Return full key object"),
|
||||
include_team_keys: bool = Query(
|
||||
@@ -2002,6 +2031,7 @@ async def list_keys(
|
||||
team_id=team_id,
|
||||
organization_id=organization_id,
|
||||
key_alias=key_alias,
|
||||
key_hash=key_hash,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
@@ -2027,6 +2057,7 @@ async def list_keys(
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
key_alias=key_alias,
|
||||
key_hash=key_hash,
|
||||
return_full_object=return_full_object,
|
||||
organization_id=organization_id,
|
||||
admin_team_ids=admin_team_ids,
|
||||
@@ -2063,6 +2094,7 @@ async def _list_key_helper(
|
||||
team_id: Optional[str],
|
||||
organization_id: Optional[str],
|
||||
key_alias: Optional[str],
|
||||
key_hash: Optional[str],
|
||||
exclude_team_id: Optional[str] = None,
|
||||
return_full_object: bool = False,
|
||||
admin_team_ids: Optional[
|
||||
@@ -2109,6 +2141,8 @@ async def _list_key_helper(
|
||||
user_condition["team_id"] = {"not": exclude_team_id}
|
||||
if organization_id and isinstance(organization_id, str):
|
||||
user_condition["organization_id"] = organization_id
|
||||
if key_hash and isinstance(key_hash, str):
|
||||
user_condition["token"] = key_hash
|
||||
|
||||
if user_condition:
|
||||
or_conditions.append(user_condition)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -14,7 +14,7 @@ import json
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import fastapi
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
@@ -85,6 +85,7 @@ from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
)
|
||||
from litellm.types.proxy.management_endpoints.team_endpoints import (
|
||||
GetTeamMemberPermissionsResponse,
|
||||
TeamListResponse,
|
||||
UpdateTeamMemberPermissionsRequest,
|
||||
)
|
||||
|
||||
@@ -537,6 +538,7 @@ async def update_team(
|
||||
detail={"error": "Team doesn't exist. Got={}".format(team_row)},
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info("Successfully updated team - %s, info", team_row.team_id)
|
||||
await _cache_team_object(
|
||||
team_id=team_row.team_id,
|
||||
team_table=LiteLLM_TeamTableCachedObj(**team_row.model_dump()),
|
||||
@@ -1553,6 +1555,150 @@ async def list_available_teams(
|
||||
return available_teams_correct_type
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v2/team/list",
|
||||
tags=["team management"],
|
||||
response_model=TeamListResponse,
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def list_team_v2(
|
||||
http_request: Request,
|
||||
user_id: Optional[str] = fastapi.Query(
|
||||
default=None, description="Only return teams which this 'user_id' belongs to"
|
||||
),
|
||||
organization_id: Optional[str] = fastapi.Query(
|
||||
default=None,
|
||||
description="Only return teams which this 'organization_id' belongs to",
|
||||
),
|
||||
team_id: Optional[str] = fastapi.Query(
|
||||
default=None, description="Only return teams which this 'team_id' belongs to"
|
||||
),
|
||||
team_alias: Optional[str] = fastapi.Query(
|
||||
default=None,
|
||||
description="Only return teams which this 'team_alias' belongs to. Supports partial matching.",
|
||||
),
|
||||
page: int = fastapi.Query(
|
||||
default=1, description="Page number for pagination", ge=1
|
||||
),
|
||||
page_size: int = fastapi.Query(
|
||||
default=10, description="Number of teams per page", ge=1, le=100
|
||||
),
|
||||
sort_by: Optional[str] = fastapi.Query(
|
||||
default=None,
|
||||
description="Column to sort by (e.g. 'team_id', 'team_alias', 'created_at')",
|
||||
),
|
||||
sort_order: str = fastapi.Query(
|
||||
default="asc", description="Sort order ('asc' or 'desc')"
|
||||
),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Get a paginated list of teams with filtering and sorting options.
|
||||
|
||||
Parameters:
|
||||
user_id: Optional[str]
|
||||
Only return teams which this user belongs to
|
||||
organization_id: Optional[str]
|
||||
Only return teams which belong to this organization
|
||||
team_id: Optional[str]
|
||||
Filter teams by exact team_id match
|
||||
team_alias: Optional[str]
|
||||
Filter teams by partial team_alias match
|
||||
page: int
|
||||
The page number to return
|
||||
page_size: int
|
||||
The number of items per page
|
||||
sort_by: Optional[str]
|
||||
Column to sort by (e.g. 'team_id', 'team_alias', 'created_at')
|
||||
sort_order: str
|
||||
Sort order ('asc' or 'desc')
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"No db connected. prisma client={prisma_client}"},
|
||||
)
|
||||
|
||||
if user_id is None and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
user_id = user_api_key_dict.user_id
|
||||
|
||||
# Calculate skip and take for pagination
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
# Build where conditions based on provided parameters
|
||||
where_conditions: Dict[str, Any] = {}
|
||||
|
||||
if team_id:
|
||||
where_conditions["team_id"] = team_id
|
||||
|
||||
if team_alias:
|
||||
where_conditions["team_alias"] = {
|
||||
"contains": team_alias,
|
||||
"mode": "insensitive", # Case-insensitive search
|
||||
}
|
||||
|
||||
if organization_id:
|
||||
where_conditions["organization_id"] = organization_id
|
||||
|
||||
if user_id:
|
||||
try:
|
||||
user_object = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"User not found, passed user_id={user_id}"},
|
||||
)
|
||||
if user_object is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"User not found, passed user_id={user_id}"},
|
||||
)
|
||||
user_object_correct_type = LiteLLM_UserTable(**user_object.model_dump())
|
||||
# Find teams where this user is a member by checking members_with_roles array
|
||||
if team_id is None:
|
||||
where_conditions["team_id"] = {"in": user_object_correct_type.teams}
|
||||
elif team_id in user_object_correct_type.teams:
|
||||
where_conditions["team_id"] = team_id
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"error": f"User is not a member of team_id={team_id}"},
|
||||
)
|
||||
|
||||
# Build order_by conditions
|
||||
valid_sort_columns = ["team_id", "team_alias", "created_at"]
|
||||
order_by = None
|
||||
if sort_by and sort_by in valid_sort_columns:
|
||||
if sort_order.lower() not in ["asc", "desc"]:
|
||||
sort_order = "asc"
|
||||
order_by = {sort_by: sort_order.lower()}
|
||||
|
||||
# Get teams with pagination
|
||||
teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
where=where_conditions,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order=order_by if order_by else {"created_at": "desc"}, # Default sort
|
||||
)
|
||||
# Get total count for pagination
|
||||
total_count = await prisma_client.db.litellm_teamtable.count(where=where_conditions)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = -(-total_count // page_size) # Ceiling division
|
||||
|
||||
return {
|
||||
"teams": [team.model_dump() for team in teams] if teams else [],
|
||||
"total": total_count,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_pages": total_pages,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ from fastapi.responses import RedirectResponse
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.constants import MAX_SPENDLOG_ROWS_TO_QUERY
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
@@ -26,6 +27,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
@@ -38,7 +40,7 @@ from litellm.proxy._types import (
|
||||
TeamMemberAddRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import get_user_object
|
||||
from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken, get_user_object
|
||||
from litellm.proxy.auth.auth_utils import _has_user_setup_sso
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
@@ -57,8 +59,8 @@ from litellm.proxy.management_endpoints.sso_helper_utils import (
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
from litellm.secret_managers.main import get_secret_bool, str_to_bool
|
||||
from litellm.types.proxy.management_endpoints.ui_sso import *
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -124,6 +126,7 @@ async def google_login(request: Request): # noqa: PLR0915
|
||||
)
|
||||
is True
|
||||
):
|
||||
verbose_proxy_logger.info(f"Redirecting to SSO login for {redirect_url}")
|
||||
return await SSOAuthenticationHandler.get_sso_login_redirect(
|
||||
redirect_url=redirect_url,
|
||||
microsoft_client_id=microsoft_client_id,
|
||||
@@ -319,9 +322,113 @@ def get_disabled_non_admin_personal_key_creation():
|
||||
return bool("proxy_admin" in allowed_user_roles)
|
||||
|
||||
|
||||
async def get_existing_user_info_from_db(
|
||||
user_id: Optional[str],
|
||||
user_email: Optional[str],
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_cache: DualCache,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> Optional[LiteLLM_UserTable]:
|
||||
try:
|
||||
user_info = await get_user_object(
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=False,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
sso_user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting user object: {e}")
|
||||
user_info = None
|
||||
|
||||
return user_info
|
||||
|
||||
|
||||
async def get_user_info_from_db(
|
||||
result: Union[CustomOpenID, OpenID, dict],
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_cache: DualCache,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
user_email: Optional[str],
|
||||
user_defined_values: Optional[SSOUserDefinedValues],
|
||||
) -> Optional[Union[LiteLLM_UserTable, NewUserResponse]]:
|
||||
try:
|
||||
user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]] = (
|
||||
await get_existing_user_info_from_db(
|
||||
user_id=cast(
|
||||
Optional[str],
|
||||
(
|
||||
getattr(result, "id", None)
|
||||
if not isinstance(result, dict)
|
||||
else result.get("id", None)
|
||||
),
|
||||
),
|
||||
user_email=cast(
|
||||
Optional[str],
|
||||
(
|
||||
getattr(result, "email", None)
|
||||
if not isinstance(result, dict)
|
||||
else result.get("email", None)
|
||||
),
|
||||
),
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
|
||||
)
|
||||
|
||||
# Upsert SSO User to LiteLLM DB
|
||||
|
||||
if user_info is None:
|
||||
user_info = await SSOAuthenticationHandler.upsert_sso_user(
|
||||
result=result,
|
||||
user_info=user_info,
|
||||
user_email=user_email,
|
||||
user_defined_values=user_defined_values,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
|
||||
result=result,
|
||||
user_info=user_info,
|
||||
)
|
||||
|
||||
return user_info
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def apply_user_info_values_to_sso_user_defined_values(
|
||||
user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]],
|
||||
user_defined_values: Optional[SSOUserDefinedValues],
|
||||
) -> Optional[SSOUserDefinedValues]:
|
||||
if user_defined_values is None:
|
||||
return None
|
||||
if user_info is not None and user_info.user_id is not None:
|
||||
user_defined_values["user_id"] = user_info.user_id
|
||||
|
||||
if user_info is None or user_info.user_role is None:
|
||||
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
else:
|
||||
user_defined_values["user_role"] = user_info.user_role
|
||||
|
||||
return user_defined_values
|
||||
|
||||
|
||||
@router.get("/sso/callback", tags=["experimental"], include_in_schema=False)
|
||||
async def auth_callback(request: Request): # noqa: PLR0915
|
||||
"""Verify login"""
|
||||
verbose_proxy_logger.info("Starting SSO callback")
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
@@ -337,6 +444,11 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
user_custom_sso,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None)
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID", None)
|
||||
generic_client_id = os.getenv("GENERIC_CLIENT_ID", None)
|
||||
@@ -374,8 +486,15 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
generic_client_id=generic_client_id,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Result not returned by SSO provider.",
|
||||
)
|
||||
|
||||
# User is Authe'd in - generate key for the UI to access Proxy
|
||||
verbose_proxy_logger.debug(f"SSO callback result: {result}")
|
||||
verbose_proxy_logger.info(f"SSO callback result: {result}")
|
||||
user_email: Optional[str] = getattr(result, "email", None)
|
||||
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
|
||||
|
||||
@@ -441,51 +560,18 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
budget_duration=internal_user_budget_duration,
|
||||
)
|
||||
|
||||
_user_id_from_sso = user_id
|
||||
user_role = None
|
||||
try:
|
||||
if prisma_client is not None:
|
||||
try:
|
||||
user_info = await get_user_object(
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
user_id_upsert=False,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
sso_user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error getting user object: {e}")
|
||||
user_info = None
|
||||
user_info = await get_user_info_from_db(
|
||||
result=result,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
user_email=user_email,
|
||||
user_defined_values=user_defined_values,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}"
|
||||
)
|
||||
|
||||
# Upsert SSO User to LiteLLM DB
|
||||
user_info = await SSOAuthenticationHandler.upsert_sso_user(
|
||||
result=result,
|
||||
user_info=user_info,
|
||||
user_email=user_email,
|
||||
user_defined_values=user_defined_values,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if user_info and user_info.user_role is not None:
|
||||
user_role = user_info.user_role
|
||||
else:
|
||||
user_role = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
|
||||
await SSOAuthenticationHandler.add_user_to_teams_from_sso_response(
|
||||
result=result,
|
||||
user_info=user_info,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"[Non-Blocking] Error trying to add sso user to db: {e}"
|
||||
)
|
||||
user_defined_values = apply_user_info_values_to_sso_user_defined_values(
|
||||
user_info=user_info, user_defined_values=user_defined_values
|
||||
)
|
||||
|
||||
if user_defined_values is None:
|
||||
raise Exception(
|
||||
@@ -507,7 +593,10 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
user_id = response["user_id"] # type: ignore
|
||||
|
||||
litellm_dashboard_ui = "/ui/"
|
||||
user_role = user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
|
||||
user_role = (
|
||||
user_defined_values["user_role"]
|
||||
or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
|
||||
)
|
||||
if (
|
||||
os.getenv("PROXY_ADMIN_ID", None) is not None
|
||||
and os.environ["PROXY_ADMIN_ID"] == user_id
|
||||
@@ -536,6 +625,30 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
|
||||
import jwt
|
||||
|
||||
if get_secret_bool("EXPERIMENTAL_UI_LOGIN"):
|
||||
_user_info: Optional[LiteLLM_UserTable] = None
|
||||
if (
|
||||
user_defined_values is not None
|
||||
and user_defined_values["user_id"] is not None
|
||||
):
|
||||
_user_info = LiteLLM_UserTable(
|
||||
user_id=user_defined_values["user_id"],
|
||||
user_role=user_defined_values["user_role"] or user_role,
|
||||
models=[],
|
||||
max_budget=litellm.max_ui_session_budget,
|
||||
)
|
||||
if _user_info is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": "User Information is required for experimental UI login"
|
||||
},
|
||||
)
|
||||
|
||||
key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token(
|
||||
_user_info
|
||||
)
|
||||
|
||||
jwt_token = jwt.encode( # type: ignore
|
||||
{
|
||||
"user_id": user_id,
|
||||
@@ -552,8 +665,10 @@ async def auth_callback(request: Request): # noqa: PLR0915
|
||||
master_key,
|
||||
algorithm="HS256",
|
||||
)
|
||||
verbose_proxy_logger.info(f"user_id: {user_id}; jwt_token: {jwt_token}")
|
||||
if user_id is not None and isinstance(user_id, str):
|
||||
litellm_dashboard_ui += "?login=success"
|
||||
verbose_proxy_logger.info(f"Redirecting to {litellm_dashboard_ui}")
|
||||
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
|
||||
redirect_response.set_cookie(key="token", value=jwt_token, secure=True)
|
||||
return redirect_response
|
||||
@@ -592,9 +707,9 @@ async def insert_sso_user(
|
||||
if user_defined_values.get("max_budget") is None:
|
||||
user_defined_values["max_budget"] = litellm.max_internal_user_budget
|
||||
if user_defined_values.get("budget_duration") is None:
|
||||
user_defined_values[
|
||||
"budget_duration"
|
||||
] = litellm.internal_user_budget_duration
|
||||
user_defined_values["budget_duration"] = (
|
||||
litellm.internal_user_budget_duration
|
||||
)
|
||||
|
||||
if user_defined_values["user_role"] is None:
|
||||
user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
|
||||
@@ -605,6 +720,8 @@ async def insert_sso_user(
|
||||
user_role=user_defined_values["user_role"], # type: ignore
|
||||
max_budget=user_defined_values["max_budget"],
|
||||
budget_duration=user_defined_values["budget_duration"],
|
||||
sso_user_id=user_defined_values["user_id"],
|
||||
auto_create_key=False,
|
||||
)
|
||||
|
||||
if result_openid:
|
||||
@@ -787,9 +904,9 @@ class SSOAuthenticationHandler:
|
||||
if state:
|
||||
redirect_params["state"] = state
|
||||
elif "okta" in generic_authorization_endpoint:
|
||||
redirect_params[
|
||||
"state"
|
||||
] = uuid.uuid4().hex # set state param for okta - required
|
||||
redirect_params["state"] = (
|
||||
uuid.uuid4().hex
|
||||
) # set state param for okta - required
|
||||
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
|
||||
raise ValueError(
|
||||
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
|
||||
@@ -1034,9 +1151,9 @@ class MicrosoftSSOHandler:
|
||||
|
||||
# if user is trying to get the raw sso response for debugging, return the raw sso response
|
||||
if return_raw_sso_response:
|
||||
original_msft_result[
|
||||
MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY
|
||||
] = user_team_ids
|
||||
original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = (
|
||||
user_team_ids
|
||||
)
|
||||
return original_msft_result or {}
|
||||
|
||||
result = MicrosoftSSOHandler.openid_from_response(
|
||||
@@ -1104,9 +1221,9 @@ class MicrosoftSSOHandler:
|
||||
|
||||
# Fetch user membership from Microsoft Graph API
|
||||
all_group_ids = []
|
||||
next_link: Optional[
|
||||
str
|
||||
] = MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||
next_link: Optional[str] = (
|
||||
MicrosoftSSOHandler.graph_api_user_groups_endpoint
|
||||
)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
||||
page_count = 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user