structure saas with tools
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1460
.venv/lib/python3.10/site-packages/litellm/proxy/auth/auth_checks.py
Normal file
1460
.venv/lib/python3.10/site-packages/litellm/proxy/auth/auth_checks.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Auth Checks for Organizations
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import status
|
||||
|
||||
from litellm.proxy._types import *
|
||||
|
||||
|
||||
def organization_role_based_access_check(
    request_body: dict,
    user_object: Optional[LiteLLM_UserTable],
    route: str,
):
    """
    Enforce organization-level role based access control for a request.

    Returns without raising when the request is permitted (or when no check applies).

    Checks, in order:
    1. No user object -> nothing to check.
    2. `/organization/new` may only be called by a proxy admin.
    3. Proxy admins pass every remaining check.
    4. For routes in `LiteLLMRoutes.org_admin_only_routes`, the user must have
       organization memberships, must pass an `organization_id` in the request
       body, and must hold the org admin role in that organization.
    5. For `/team/new`, a user who belongs to at least one organization must
       pass an `organization_id` and be org admin of that organization.

    Args:
        request_body: parsed request body; only `organization_id` is read.
        user_object: the calling user, or None.
        route: the route being called.

    Raises:
        ProxyException: (auth_error, HTTP 401) when any check above fails.
    """

    if user_object is None:
        return

    passed_organization_id: Optional[str] = request_body.get("organization_id", None)

    if route == "/organization/new":
        if user_object.user_role != LitellmUserRoles.PROXY_ADMIN.value:
            raise ProxyException(
                message=f"Only proxy admins can create new organizations. You are {user_object.user_role}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )

    # Proxy admins are exempt from the per-organization checks below
    if user_object.user_role == LitellmUserRoles.PROXY_ADMIN.value:
        return

    # Checks if route is an Org Admin Only Route
    if route in LiteLLMRoutes.org_admin_only_routes.value:
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)

        if user_object.organization_memberships is None:
            raise ProxyException(
                message=f"Tried to access route={route} but you are not a member of any organization. Please contact the proxy admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        if passed_organization_id is None:
            raise ProxyException(
                message="Passed organization_id is None, please pass an organization_id in your request",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        # None here means the user has no membership in the requested organization
        user_role: Optional[LitellmUserRoles] = _user_organization_role_mapping.get(
            passed_organization_id
        )
        if user_role is None:
            raise ProxyException(
                message=f"You do not have a role within the selected organization. Passed organization_id: {passed_organization_id}. Please contact the organization admin to request access.",
                type=ProxyErrorTypes.auth_error.value,
                param="organization_id",
                code=status.HTTP_401_UNAUTHORIZED,
            )

        if user_role != LitellmUserRoles.ORG_ADMIN.value:
            raise ProxyException(
                message=f"You do not have the required role to perform {route} in Organization {passed_organization_id}. Your role is {user_role} in Organization {passed_organization_id}",
                type=ProxyErrorTypes.auth_error.value,
                param="user_role",
                code=status.HTTP_401_UNAUTHORIZED,
            )
    elif route == "/team/new":
        # if the user belongs to at least one organization, they need to specify
        # the organization_id and be an org admin of it
        (
            _user_organizations,
            _user_organization_role_mapping,
        ) = get_user_organization_info(user_object)
        if (
            user_object.organization_memberships is not None
            and len(user_object.organization_memberships) > 0
        ):
            if passed_organization_id is None:
                raise ProxyException(
                    message=f"Passed organization_id is None, please specify the organization_id in your request. You are part of multiple organizations: {_user_organizations}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="organization_id",
                    code=status.HTTP_401_UNAUTHORIZED,
                )

            _user_role_in_passed_org = _user_organization_role_mapping.get(
                passed_organization_id
            )
            if _user_role_in_passed_org != LitellmUserRoles.ORG_ADMIN.value:
                raise ProxyException(
                    message=f"You do not have the required role to call {route}. Your role is {_user_role_in_passed_org} in Organization {passed_organization_id}",
                    type=ProxyErrorTypes.auth_error.value,
                    param="user_role",
                    code=status.HTTP_401_UNAUTHORIZED,
                )
|
||||
|
||||
|
||||
def get_user_organization_info(
    user_object: LiteLLM_UserTable,
) -> Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]:
    """
    Collect the organizations a user belongs to and the user's role in each.

    Memberships whose `organization_id` is None are ignored.

    Args:
        user_object (LiteLLM_UserTable): The user whose `organization_memberships` are read.

    Returns:
        Tuple[List[str], Dict[str, Optional[LitellmUserRoles]]]:
            - organization IDs, in membership order
            - mapping of organization ID -> the membership's `user_role`
        Both are empty when `organization_memberships` is None.
    """
    org_ids: List[str] = []
    role_by_org: Dict[str, Optional[LitellmUserRoles]] = {}

    memberships = user_object.organization_memberships
    if memberships is None:
        return org_ids, role_by_org

    for membership in memberships:
        org_id = membership.organization_id
        if org_id is None:
            continue
        org_ids.append(org_id)
        role_by_org[org_id] = membership.user_role  # type: ignore

    return org_ids, role_by_org
|
||||
|
||||
|
||||
def _user_is_org_admin(
    request_data: dict,
    user_object: Optional[LiteLLM_UserTable] = None,
) -> bool:
    """
    Return True when the user is an org admin of the `organization_id` in `request_data`.

    Returns False when no `organization_id` was passed, when there is no user
    object, or when the user has no organization memberships.
    """
    target_org_id = request_data.get("organization_id", None)
    if target_org_id is None or user_object is None:
        return False

    memberships = user_object.organization_memberships
    if memberships is None:
        return False

    return any(
        membership.organization_id == target_org_id
        and membership.user_role == LitellmUserRoles.ORG_ADMIN.value
        for membership in memberships
    )
|
||||
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Handles Authentication Errors
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_utils import _get_request_ip_address
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class UserAPIKeyAuthExceptionHandler:
    """
    Converts exceptions raised during API key authentication into either a
    fallback `UserAPIKeyAuth` (database outage, when allowed) or a `ProxyException`.
    """

    # The event loop only keeps weak references to tasks. Hold a strong reference
    # to each fire-and-forget failure-logging task until it completes, so it
    # cannot be garbage-collected mid-execution.
    _background_tasks: set = set()

    @staticmethod
    async def _handle_authentication_error(
        e: Exception,
        request: Request,
        request_data: dict,
        route: str,
        parent_otel_span: Optional[Span],
        api_key: str,
    ) -> UserAPIKeyAuth:
        """
        Handles Connection Errors when reading a Virtual Key from LiteLLM DB
        Use this if you don't want failed DB queries to block LLM API requests

        Reliability scenarios this covers:
        - DB is down and having an outage
        - Unable to read / recover a key from the DB

        Args:
            e: the exception raised during authentication.
            request: the incoming request (used to resolve the requester IP for logging).
            request_data: parsed request body, forwarded to the failure hook.
            route: the route being called.
            parent_otel_span: tracing span attached to the logged failure, if any.
            api_key: the API key presented by the caller.

        Returns:
            - UserAPIKeyAuth: If PrismaDBExceptionHandler allows requests while the
              DB is unavailable and `e` is a database connection error

        Raises:
            - ProxyException in all other cases (an incoming ProxyException is
              re-raised unchanged; BudgetExceededError maps to code 400;
              HTTPException keeps its status code and detail; anything else
              becomes a 401 auth error)
        """
        from litellm.proxy.proxy_server import (
            general_settings,
            litellm_proxy_admin_name,
            proxy_logging_obj,
        )

        if (
            PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
            and PrismaDBExceptionHandler.is_database_connection_error(e)
        ):
            # log this as a DB failure on prometheus
            proxy_logging_obj.service_logging_obj.service_failure_hook(
                service=ServiceTypes.DB,
                call_type="get_key_object",
                error=e,
                duration=0.0,
            )

            return UserAPIKeyAuth(
                key_name="failed-to-connect-to-db",
                token="failed-to-connect-to-db",
                user_id=litellm_proxy_admin_name,
                request_route=route,
            )
        else:
            # raise the exception to the caller
            requester_ip = _get_request_ip_address(
                request=request,
                use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
            )
            verbose_proxy_logger.exception(
                "litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\nRequester IP Address:{}".format(
                    str(e),
                    requester_ip,
                ),
                extra={"requester_ip": requester_ip},
            )

            # Log this exception to OTEL, Datadog etc
            user_api_key_dict = UserAPIKeyAuth(
                parent_otel_span=parent_otel_span,
                api_key=api_key,
                request_route=route,
            )
            failure_logging_task = asyncio.create_task(
                proxy_logging_obj.post_call_failure_hook(
                    request_data=request_data,
                    original_exception=e,
                    user_api_key_dict=user_api_key_dict,
                    error_type=ProxyErrorTypes.auth_error,
                    route=route,
                )
            )
            # Keep the task alive until done, then drop the reference
            UserAPIKeyAuthExceptionHandler._background_tasks.add(failure_logging_task)
            failure_logging_task.add_done_callback(
                UserAPIKeyAuthExceptionHandler._background_tasks.discard
            )

            if isinstance(e, litellm.BudgetExceededError):
                raise ProxyException(
                    message=e.message,
                    type=ProxyErrorTypes.budget_exceeded,
                    param=None,
                    code=400,
                )
            if isinstance(e, HTTPException):
                raise ProxyException(
                    message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                    type=ProxyErrorTypes.auth_error,
                    param=getattr(e, "param", "None"),
                    code=getattr(e, "status_code", status.HTTP_401_UNAUTHORIZED),
                )
            elif isinstance(e, ProxyException):
                raise e
            raise ProxyException(
                message="Authentication Error, " + str(e),
                type=ProxyErrorTypes.auth_error,
                param=getattr(e, "param", "None"),
                code=status.HTTP_401_UNAUTHORIZED,
            )
|
||||
@@ -0,0 +1,513 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from litellm import Router, provider_list
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.types.router import CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS
|
||||
|
||||
|
||||
def _get_request_ip_address(
    request: Request, use_x_forwarded_for: Optional[bool] = False
) -> Optional[str]:
    """
    Resolve the caller's IP address for a request.

    The raw `x-forwarded-for` header value is returned when
    `use_x_forwarded_for` is True and the header is present; otherwise the
    host of `request.client`, or an empty string when there is no client.
    """
    if use_x_forwarded_for is True and "x-forwarded-for" in request.headers:
        return request.headers["x-forwarded-for"]

    if request.client is not None:
        return request.client.host

    return ""
|
||||
|
||||
|
||||
def _check_valid_ip(
    allowed_ips: Optional[List[str]],
    request: Request,
    use_x_forwarded_for: Optional[bool] = False,
) -> Tuple[bool, Optional[str]]:
    """
    Check the caller's IP address against an allow-list.

    Returns:
        (is_allowed, client_ip). When `allowed_ips` is None no restriction is
        configured and `(True, None)` is returned without resolving the IP.
    """
    # No allow-list configured -> every caller is allowed
    if allowed_ips is None:
        return True, None

    # x-forwarded-for is only honoured when use_x_forwarded_for is True
    caller_ip = _get_request_ip_address(
        request=request, use_x_forwarded_for=use_x_forwarded_for
    )

    return caller_ip in allowed_ips, caller_ip
|
||||
|
||||
|
||||
def check_complete_credentials(request_body: dict) -> bool:
    """
    Decide whether the request carries its own complete credentials.

    Used when 'api_base' is in the request body, to prevent malicious requests.
    Returns True only when a model is given, the model name does not refer to a
    provider with complex credentials, and an 'api_key' is present in the body.
    """
    model_name = request_body.get("model")
    if model_name is None:
        return False

    # complex credentials - easier to make a malicious request
    complex_credential_markers = (
        "sagemaker",
        "bedrock",
        "vertex_ai",
        "vertex_ai_beta",
    )
    if any(marker in model_name for marker in complex_credential_markers):
        return False

    return "api_key" in request_body
|
||||
|
||||
|
||||
def check_regex_or_str_match(request_body_value: Any, regex_str: str) -> bool:
    """
    Check if request_body_value matches the regex_str or is equal to it.

    Args:
        request_body_value: value taken from the (untrusted) request body.
        regex_str: the allowed value, interpreted as a regular expression.

    Returns:
        True if the value equals `regex_str`, or is a string that `regex_str`
        matches. False otherwise, including when the value is not a string or
        `regex_str` is not a valid pattern.

    Note:
        `re.match` anchors only at the start of the string, so a pattern
        without a trailing `$` also accepts values that merely begin with a match.
    """
    # Equality first: a literal containing regex metacharacters (e.g. "a(")
    # must still match itself even though it is not a valid pattern.
    if regex_str == request_body_value:
        return True

    # re.match raises TypeError on non-string input; such values cannot match
    if not isinstance(request_body_value, str):
        return False

    try:
        return re.match(regex_str, request_body_value) is not None
    except (re.error, TypeError):
        # invalid or non-string pattern: only the equality check above applies
        return False
|
||||
|
||||
|
||||
def _is_param_allowed(
    param: str,
    request_body_value: Any,
    configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS,
) -> bool:
    """
    Check whether `param` may be set client-side according to the allow-list.

    Each allow-list entry is either a string (the param name is allowed
    outright) or a dict, which is only consulted for `api_base`: the request
    value must match the entry's "api_base" regex or string.
    """
    if configurable_clientside_auth_params is None:
        return False

    for allowed_entry in configurable_clientside_auth_params:
        if isinstance(allowed_entry, str):
            if param == allowed_entry:
                return True
            continue

        if not isinstance(allowed_entry, Dict):
            continue

        # dict entries only constrain api_base; the configured value is treated as a regex
        if param != "api_base":
            continue
        if check_regex_or_str_match(
            request_body_value=request_body_value,
            regex_str=allowed_entry["api_base"],
        ):
            return True

    return False
|
||||
|
||||
|
||||
def _allow_model_level_clientside_configurable_parameters(
    model: str, param: str, request_body_value: Any, llm_router: Optional[Router]
) -> bool:
    """
    Check if the model is allowed to use a configurable client-side param.

    - look up the model group info for `model`
    - if none is found and the part of `model` before the first "/" is a known
      provider, look up the model group info for that provider (wildcard model)
    - check `param` against the group's `configurable_clientside_auth_params`

    Returns False when there is no router, no matching model group, or no
    client-side params are configured for it.
    """
    if llm_router is None:
        return False

    group_info = llm_router.get_model_group_info(model_group=model)
    if group_info is None:
        # check if wildcard model is set
        provider_prefix = model.split("/", 1)[0]
        if provider_prefix in provider_list:
            group_info = llm_router.get_model_group_info(model_group=provider_prefix)

    if group_info is None:
        return False

    allowed_params = group_info.configurable_clientside_auth_params
    if allowed_params is None:
        return False

    return _is_param_allowed(
        param=param,
        request_body_value=request_body_value,
        configurable_clientside_auth_params=allowed_params,
    )
|
||||
|
||||
|
||||
def is_request_body_safe(
    request_body: dict, general_settings: dict, llm_router: Optional[Router], model: str
) -> bool:
    """
    Check if the request body is safe.

    A malicious user can set the api_base to their own domain and invoke POST /chat/completions to intercept and steal the OpenAI API key.
    Relevant issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997

    A banned param ("api_base", "base_url") present in the body is accepted when:
    - the request carries complete client credentials, or
    - `general_settings["allow_client_side_credentials"]` is True, or
    - the model's client-side configurable params allow that param and value.

    Args:
        request_body: parsed request body.
        general_settings: proxy general settings.
        llm_router: router used to look up model-level client-side params.
        model: model name the request targets.

    Returns:
        True when the body is safe.

    Raises:
        ValueError: when a banned param is present and not allowed.
    """
    banned_params = ["api_base", "base_url"]

    for param in banned_params:
        if param not in request_body:
            continue

        # allow client-credentials to be passed to proxy
        if check_complete_credentials(request_body=request_body):
            continue

        # proxy-wide opt-in covers every banned param
        if general_settings.get("allow_client_side_credentials") is True:
            return True

        if (
            _allow_model_level_clientside_configurable_parameters(
                model=model,
                param=param,
                request_body_value=request_body[param],
                llm_router=llm_router,
            )
            is True
        ):
            # Only THIS param was approved at model level. Keep validating the
            # remaining banned params instead of returning early, otherwise an
            # allowed api_base would let an unchecked base_url through.
            continue

        raise ValueError(
            f"Rejected Request: {param} is not allowed in request body. "
            "Enable with `general_settings::allow_client_side_credentials` on proxy config.yaml. "
            "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
        )

    return True
|
||||
|
||||
|
||||
async def pre_db_read_auth_checks(
    request: Request,
    request_data: dict,
    route: str,
):
    """
    Run the auth checks that do not require reading from the database.

    1. Checks if request size is under max_request_size_mb (if set)
    2. Check if request body is safe (example user has not set api_base in request body)
    3. Check if IP address is allowed (if set)
    4. Check if request route is an allowed route on the proxy (if set)

    Args:
        request: the incoming request.
        request_data: parsed request body.
        route: the route being called.

    Returns:
        - None; the function only returns normally when every check passes

    Raises:
        - HTTPException (403) if the IP address or the route is not allowed
        - any exception raised by the request-size and request-body checks
    """
    from litellm.proxy.proxy_server import general_settings, llm_router, premium_user

    # Check 1. request size
    await check_if_request_size_is_safe(request=request)

    # Check 2. Request body is safe
    is_request_body_safe(
        request_body=request_data,
        general_settings=general_settings,
        llm_router=llm_router,
        model=request_data.get(
            "model", ""
        ),  # [TODO] use model passed in url as well (azure openai routes)
    )

    # Check 3. Check if IP address is allowed
    is_valid_ip, passed_in_ip = _check_valid_ip(
        allowed_ips=general_settings.get("allowed_ips", None),
        use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False),
        request=request,
    )

    if not is_valid_ip:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access forbidden: IP address {passed_in_ip} not allowed.",
        )

    # Check 4. Check if request route is an allowed route on the proxy
    if "allowed_routes" in general_settings:
        _allowed_routes = general_settings["allowed_routes"]
        # NOTE: for non-premium users this only logs an error; the
        # allowed_routes restriction below is still enforced.
        if premium_user is not True:
            verbose_proxy_logger.error(
                f"Trying to set allowed_routes. This is an Enterprise feature. {CommonProxyErrors.not_premium_user.value}"
            )
        if route not in _allowed_routes:
            verbose_proxy_logger.error(
                f"Route {route} not in allowed_routes={_allowed_routes}"
            )
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access forbidden: Route {route} not allowed",
            )
|
||||
|
||||
|
||||
def route_in_additonal_public_routes(current_route: str):
    """
    Tell whether `current_route` is listed in the user-defined `public_routes`.

    Parameters:
    - current_route: str - the route the user is trying to call

    Returns:
    - bool - True if the route is defined in public_routes
    - bool - False if it is not, if the proxy is not running as premium_user,
      if general_settings is None, or if the lookup raises

    To use this, the litellm config.yaml should have the following in general_settings:

    ```yaml
    general_settings:
        master_key: sk-1234
        public_routes: ["LiteLLMRoutes.public_routes", "/spend/calculate"]
    ```
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    try:
        # only premium users can define public routes; settings may be unset
        if premium_user is not True or general_settings is None:
            return False

        return current_route in general_settings.get("public_routes", [])
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False
|
||||
|
||||
|
||||
def get_request_route(request: Request) -> str:
    """
    Return the route of a request with the base url path stripped.

    e.g. `/genai/chat/completions` -> `/chat/completions`

    Falls back to `request.url.path` when the request has no `base_url`, when
    the path does not start with the base url path, or when resolving it raises.
    """
    try:
        if not hasattr(request, "base_url"):
            return request.url.path

        base_path = request.base_url.path
        full_path = request.url.path
        if full_path.startswith(base_path):
            # keep the character at len(base_path) - 1 so the route keeps its
            # leading separator (assumes base_path ends with "/" — TODO confirm)
            return full_path[len(base_path) - 1 :]
        return full_path
    except Exception as e:
        verbose_proxy_logger.debug(
            f"error on get_request_route: {str(e)}, defaulting to request.url.path={request.url.path}"
        )
        return request.url.path
|
||||
|
||||
|
||||
async def check_if_request_size_is_safe(request: Request) -> bool:
    """
    Enterprise Only:
        - Checks if the request size is within `general_settings["max_request_size_mb"]`

    The size is taken from the `content-length` header when present; otherwise
    the request body is read and measured.

    Args:
        request (Request): The incoming request.

    Returns:
        bool: True if the request size is within the limit, if no limit is
        configured, or if the proxy is not running as premium_user (the check
        is skipped with a warning).

    Raises:
        ProxyException: If the request size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_request_size_mb = general_settings.get("max_request_size_mb", None)
    if max_request_size_mb is None:
        return True

    # Check if premium user
    if premium_user is not True:
        verbose_proxy_logger.warning(
            f"using max_request_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    declared_length = request.headers.get("content-length")
    if declared_length:
        measured_mb = bytes_to_mb(bytes_value=int(declared_length))
        verbose_proxy_logger.debug(
            f"content_length request size in MB={measured_mb}"
        )
    else:
        # If Content-Length is not available, read the body
        raw_body = await request.body()
        measured_mb = bytes_to_mb(bytes_value=len(raw_body))
        verbose_proxy_logger.debug(
            f"request body request size in MB={measured_mb}"
        )

    if measured_mb > max_request_size_mb:
        raise ProxyException(
            message=f"Request size is too large. Request size is {measured_mb} MB. Max size is {max_request_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )

    return True
|
||||
|
||||
|
||||
async def check_response_size_is_safe(response: Any) -> bool:
    """
    Enterprise Only:
        - Checks if the response size is within `general_settings["max_response_size_mb"]`

    Args:
        response (Any): The response to check.

    Returns:
        bool: True if the response size is within the limit, if no limit is
        configured, or if the proxy is not running as premium_user (the check
        is skipped with a warning).

    Raises:
        ProxyException: If the response size is too large
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    max_response_size_mb = general_settings.get("max_response_size_mb", None)
    if max_response_size_mb is None:
        return True

    # Check if premium user
    if premium_user is not True:
        verbose_proxy_logger.warning(
            f"using max_response_size_mb - not checking - this is an enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
        )
        return True

    # NOTE(review): sys.getsizeof is shallow and does not follow references into
    # contained objects, so this may under-report nested responses — confirm intent.
    measured_mb = bytes_to_mb(bytes_value=sys.getsizeof(response))
    verbose_proxy_logger.debug(f"response size in MB={measured_mb}")

    if measured_mb > max_response_size_mb:
        raise ProxyException(
            message=f"Response size is too large. Response size is {measured_mb} MB. Max size is {max_response_size_mb} MB",
            type=ProxyErrorTypes.bad_request_error.value,
            code=400,
            param="content-length",
        )

    return True
|
||||
|
||||
|
||||
def bytes_to_mb(bytes_value: int):
    """
    Convert a byte count to megabytes, using 1 MB = 1024 * 1024 bytes.
    """
    bytes_per_mb = 1024 * 1024
    return bytes_value / bytes_per_mb
|
||||
|
||||
|
||||
# helpers used by parallel request limiter to handle model rpm/tpm limits for a given api key
|
||||
def get_key_model_rpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Return the per-model rpm limits configured for an API key.

    When the key has non-empty metadata, only `metadata["model_rpm_limit"]` is
    consulted. Otherwise the limits are collected from the non-None
    `rpm_limit` entries of `model_max_budget`. Returns None when neither
    source is set.
    """
    key_metadata = user_api_key_dict.metadata
    if key_metadata:
        if "model_rpm_limit" in key_metadata:
            return key_metadata["model_rpm_limit"]
        # non-empty metadata without the key: model_max_budget is not consulted
        return None

    per_model_budget = user_api_key_dict.model_max_budget
    if per_model_budget:
        collected: Dict[str, Any] = {}
        for model_name, model_budget in per_model_budget.items():
            if "rpm_limit" not in model_budget:
                continue
            if model_budget["rpm_limit"] is None:
                continue
            collected[model_name] = model_budget["rpm_limit"]
        return collected

    return None
|
||||
|
||||
|
||||
def get_key_model_tpm_limit(
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[Dict[str, int]]:
    """
    Return the per-model tpm limits configured for an API key.

    When the key has non-empty metadata, only `metadata["model_tpm_limit"]` is
    consulted. Otherwise the limits are read from `model_max_budget`:
    - a top-level "tpm_limit" entry is returned as-is (previous behaviour, kept
      for backward compatibility)
    - else the non-None `tpm_limit` of each per-model budget is collected,
      mirroring `get_key_model_rpm_limit`

    Returns:
        Mapping of model name -> tpm limit, or None when no limit is configured.
    """
    key_metadata = user_api_key_dict.metadata
    if key_metadata:
        if "model_tpm_limit" in key_metadata:
            return key_metadata["model_tpm_limit"]
        # non-empty metadata without the key: model_max_budget is not consulted
        return None

    per_model_budget = user_api_key_dict.model_max_budget
    if per_model_budget:
        if "tpm_limit" in per_model_budget:
            return per_model_budget["tpm_limit"]

        # model_max_budget is keyed by model name (see the rpm helper), so the
        # tpm limit lives inside each model's budget entry.
        model_tpm_limit: Dict[str, Any] = {}
        for model_name, model_budget in per_model_budget.items():
            if (
                isinstance(model_budget, dict)
                and model_budget.get("tpm_limit") is not None
            ):
                model_tpm_limit[model_name] = model_budget["tpm_limit"]
        # keep returning None (as before) when no per-model tpm limit exists
        return model_tpm_limit or None

    return None
|
||||
|
||||
|
||||
def is_pass_through_provider_route(route: str) -> bool:
    """
    Return True when `route` contains a provider-specific pass-through prefix.
    """
    provider_prefixes = ("vertex-ai",)
    return any(provider_prefix in route for provider_prefix in provider_prefixes)
|
||||
|
||||
|
||||
def should_run_auth_on_pass_through_provider_route(route: str) -> bool:
    """
    Decide if the rest of the LiteLLM Virtual Key auth checks should run on
    provider pass through routes, e.g. /vertex-ai/{endpoint} routes.

    Virtual key auth runs only when both hold:
    - the proxy is running as premium_user (LiteLLM Enterprise)
    - general_settings.use_client_credentials_pass_through_routes is NOT
      enabled (enabling it means the client's own credentials are used)

    `route` is currently not inspected.
    """
    from litellm.proxy.proxy_server import general_settings, premium_user

    # only enabled for LiteLLM Enterprise
    if premium_user is not True:
        return False

    # premium user has opted into using client credentials -> skip virtual key auth
    uses_client_credentials = (
        general_settings.get("use_client_credentials_pass_through_routes", False)
        is True
    )
    return not uses_client_credentials
|
||||
|
||||
|
||||
def _has_user_setup_sso():
    """
    Check if single sign-on (SSO) has been set up, i.e. whether any of the
    MICROSOFT_CLIENT_ID, GOOGLE_CLIENT_ID or GENERIC_CLIENT_ID environment
    variables is set.

    Returns a boolean indicating whether SSO has been set up.
    """
    sso_client_id_env_vars = (
        "MICROSOFT_CLIENT_ID",
        "GOOGLE_CLIENT_ID",
        "GENERIC_CLIENT_ID",
    )
    return any(
        os.getenv(env_name, None) is not None for env_name in sso_client_id_env_vars
    )
|
||||
|
||||
|
||||
def get_end_user_id_from_request_body(request_body: dict) -> Optional[str]:
    """
    Extract the end-user id from a request body.

    Sources, in priority order:
    1. `user` (openai) - any non-None value
    2. `litellm_metadata.user` (anthropic) - any truthy value
    3. `metadata.user_id` - any non-None value

    `litellm_metadata` / `metadata` values that are missing, None, or not a
    dict are skipped instead of raising.

    Returns:
        The id converted to `str`, or None when no source provides one.
    """
    # openai - check 'user'
    openai_user = request_body.get("user")
    if openai_user is not None:
        return str(openai_user)

    # anthropic - check 'litellm_metadata'
    # (an explicit `"litellm_metadata": null` must not crash the lookup)
    litellm_metadata = request_body.get("litellm_metadata")
    if isinstance(litellm_metadata, dict):
        end_user_id = litellm_metadata.get("user", None)
        if end_user_id:
            return str(end_user_id)

    metadata = request_body.get("metadata")
    if isinstance(metadata, dict) and metadata.get("user_id") is not None:
        return str(metadata["user_id"])

    return None
|
||||
@@ -0,0 +1,998 @@
|
||||
"""
|
||||
Supports using JWT's for authenticating into the proxy.
|
||||
|
||||
Currently only supports admin.
|
||||
|
||||
JWT token must have 'litellm_proxy_admin' in scope.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, List, Literal, Optional, Set, Tuple, cast
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value
|
||||
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
|
||||
from litellm.proxy._types import (
|
||||
RBAC_ROLES,
|
||||
JWKKeyValue,
|
||||
JWTAuthBuilderResult,
|
||||
JWTKeyItem,
|
||||
LiteLLM_EndUserTable,
|
||||
LiteLLM_JWTAuth,
|
||||
LiteLLM_OrganizationTable,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
ScopeMapping,
|
||||
Span,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import can_team_access_model
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
||||
from .auth_checks import (
|
||||
_allowed_routes_check,
|
||||
allowed_routes_check,
|
||||
get_actual_routes,
|
||||
get_end_user_object,
|
||||
get_org_object,
|
||||
get_role_based_models,
|
||||
get_role_based_routes,
|
||||
get_team_object,
|
||||
get_user_object,
|
||||
)
|
||||
|
||||
|
||||
class JWTHandler:
|
||||
"""
|
||||
- treat the sub id passed in as the user id
|
||||
- return an error if id making request doesn't exist in proxy user table
|
||||
- track spend against the user id
|
||||
- if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets
|
||||
"""
|
||||
|
||||
prisma_client: Optional[PrismaClient]
|
||||
user_api_key_cache: DualCache
|
||||
|
||||
def __init__(
    self,
) -> None:
    """Create the handler with its HTTP client and no clock-skew leeway."""
    self.leeway = 0
    self.http_handler = HTTPHandler()
|
||||
|
||||
def update_environment(
    self,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    litellm_jwtauth: LiteLLM_JWTAuth,
    leeway: int = 0,
) -> None:
    """
    Store the runtime dependencies and JWT auth settings on the handler.

    Args:
        prisma_client: database client, or None.
        user_api_key_cache: cache stored on the handler.
        litellm_jwtauth: JWT auth configuration read by the other methods.
        leeway: leeway value stored on the handler (defaults to 0).
    """
    self.litellm_jwtauth = litellm_jwtauth
    self.leeway = leeway
    self.prisma_client = prisma_client
    self.user_api_key_cache = user_api_key_cache
|
||||
|
||||
def is_jwt(self, token: str):
    """Return True when `token` consists of exactly three dot-separated parts."""
    # three parts <=> exactly two "." separators
    return token.count(".") == 2
|
||||
|
||||
def _rbac_role_from_role_mapping(self, token: dict) -> Optional[RBAC_ROLES]:
    """
    Returns the RBAC role the token 'belongs' to based on role mappings.

    Args:
        token (dict): The JWT token containing role information

    Returns:
        Optional[RBAC_ROLES]: The mapped internal RBAC role if a mapping exists,
        None otherwise (also when no role mappings are configured or the
        token carries no role)

    Note:
        The function handles both single string roles and lists of roles from the JWT.
        If multiple mappings match the JWT roles, the first matching mapping is returned.
    """
    role_mappings = self.litellm_jwtauth.role_mappings
    if role_mappings is None:
        return None

    jwt_role = self.get_jwt_role(token=token, default_value=None)
    if not jwt_role:
        return None

    # A single string role must be treated as one role: set("admin") would
    # otherwise explode it into the characters {"a", "d", "m", "i", "n"}.
    if isinstance(jwt_role, str):
        jwt_role_set = {jwt_role}
    else:
        jwt_role_set = set(jwt_role)

    for role_mapping in role_mappings:
        # Check if the mapping role matches any of the JWT roles
        if role_mapping.role in jwt_role_set:
            return role_mapping.internal_role

    return None
|
||||
|
||||
def get_rbac_role(self, token: dict) -> Optional[RBAC_ROLES]:
    """
    Work out which RBAC role a decoded token maps to.

    Resolves: https://github.com/BerriAI/litellm/issues/6793

    The checks run in this order and the first hit wins:
    - PROXY_ADMIN: the admin scope is present in the token scopes
    - TEAM: a team id can be resolved for the token
    - INTERNAL_USER: a user id can be resolved for the token
    - INTERNAL_USER: one of the token's user roles is an allowed user role
    - the role produced by the configured role mappings, if any

    Returns:
    - the matched role, or None if none of the checks matched
    """
    token_scopes = self.get_scopes(token=token)
    admin_token = self.is_admin(scopes=token_scopes)
    roles_claim = self.get_user_roles(token=token, default_value=None)

    if admin_token:
        return LitellmUserRoles.PROXY_ADMIN

    if self.get_team_id(token=token, default_value=None) is not None:
        return LitellmUserRoles.TEAM

    if self.get_user_id(token=token, default_value=None) is not None:
        return LitellmUserRoles.INTERNAL_USER

    if roles_claim is not None and self.is_allowed_user_role(user_roles=roles_claim):
        return LitellmUserRoles.INTERNAL_USER

    mapped_role = self._rbac_role_from_role_mapping(token=token)
    if mapped_role:
        return mapped_role
    return None
|
||||
|
||||
def is_admin(self, scopes: list) -> bool:
    """Return True when the configured admin scope is among `scopes`."""
    admin_scope = self.litellm_jwtauth.admin_jwt_scope
    return admin_scope in scopes
|
||||
|
||||
def get_team_ids_from_jwt(self, token: dict) -> List[str]:
    """
    Read the team ids claim from the token.

    Returns:
    - the value stored under 'team_ids_jwt_field' when that field is
      configured and present (non-None) in the token
    - an empty list otherwise
    """
    field = self.litellm_jwtauth.team_ids_jwt_field
    if field is None:
        return []

    team_ids = token.get(field)
    if team_ids is None:
        return []
    return team_ids
|
||||
|
||||
def get_end_user_id(
    self, token: dict, default_value: Optional[str]
) -> Optional[str]:
    """
    Read the end-user id claim from the token.

    Returns:
    - None: if 'end_user_id_jwt_field' is not configured
    - the claim value: if the field is present in the token
    - default_value: if the field is configured but missing from the token
    """
    field = self.litellm_jwtauth.end_user_id_jwt_field
    if field is None:
        return None
    return token.get(field, default_value)
|
||||
|
||||
def is_required_team_id(self) -> bool:
    """
    Tell whether a team id claim is expected in tokens.

    Returns:
    - True: if 'team_id_jwt_field' is set
    - False: if it is None
    """
    return self.litellm_jwtauth.team_id_jwt_field is not None
|
||||
|
||||
def is_enforced_email_domain(self) -> bool:
    """
    Tell whether an allowed email domain is configured.

    Returns:
    - True: if 'user_allowed_email_domain' is set to a string
    - False: if it is None (or not a string)
    """
    # isinstance() already rejects None, so no separate None check is needed.
    allowed_domain = self.litellm_jwtauth.user_allowed_email_domain
    return isinstance(allowed_domain, str)
|
||||
|
||||
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Resolve the team id for a token.

    Returns:
    - the claim stored under 'team_id_jwt_field' when that field is configured
      (default_value if the token lacks it)
    - 'team_id_default' when no field is configured but a default team is
    - None when neither is configured
    """
    auth_config = self.litellm_jwtauth

    if auth_config.team_id_jwt_field is not None:
        return token.get(auth_config.team_id_jwt_field, default_value)

    if auth_config.team_id_default is not None:
        return auth_config.team_id_default

    return None
|
||||
|
||||
def is_upsert_user_id(self, valid_user_email: Optional[bool] = None) -> bool:
    """
    Decide whether an unknown user id may be upserted.

    Returns:
    - False: if valid_user_email is explicitly False
    - otherwise the configured 'user_id_upsert' value
    """
    email_rejected = valid_user_email is False
    return False if email_rejected else self.litellm_jwtauth.user_id_upsert
|
||||
|
||||
def get_user_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Read the user id claim from the token.

    Returns default_value when 'user_id_jwt_field' is not configured or the
    token does not contain that field.
    """
    field = self.litellm_jwtauth.user_id_jwt_field
    if field is None:
        return default_value
    return token.get(field, default_value)
|
||||
|
||||
def get_user_roles(
    self, token: dict, default_value: Optional[List[str]]
) -> Optional[List[str]]:
    """
    Look up the user roles claim in the token.

    The claim location is configured via 'user_roles_jwt_field' and is
    resolved with `get_nested_value`. default_value is returned when the
    field is not configured or the lookup raises KeyError.
    """
    roles_field = self.litellm_jwtauth.user_roles_jwt_field
    if roles_field is None:
        return default_value

    try:
        return get_nested_value(
            data=token,
            key_path=roles_field,
            default=default_value,
        )
    except KeyError:
        return default_value
|
||||
|
||||
def get_jwt_role(
    self, token: dict, default_value: Optional[List[str]]
) -> Optional[List[str]]:
    """
    Generic counterpart of `get_user_roles`, usable for both user and team roles.

    The claim location is configured via 'roles_jwt_field' and is resolved
    with `get_nested_value`. default_value is returned when the field is not
    configured or the lookup raises KeyError.
    """
    roles_field = self.litellm_jwtauth.roles_jwt_field
    if roles_field is None:
        return default_value

    try:
        return get_nested_value(
            data=token,
            key_path=roles_field,
            default=default_value,
        )
    except KeyError:
        return default_value
|
||||
|
||||
def is_allowed_user_role(self, user_roles: Optional[List[str]]) -> bool:
    """
    Check whether any of the given roles is an allowed user role.

    The allowed roles are set via 'user_allowed_roles' in the config.

    Returns:
    - True: if at least one entry of user_roles is in 'user_allowed_roles'
    - False: if user_roles is None, no allowed roles are configured, or
      nothing matches
    """
    if user_roles is None:
        return False

    allowed_roles = self.litellm_jwtauth.user_allowed_roles
    if allowed_roles is None:
        return False

    for candidate in user_roles:
        if candidate in allowed_roles:
            return True
    return False
|
||||
|
||||
def get_user_email(
    self, token: dict, default_value: Optional[str]
) -> Optional[str]:
    """
    Read the user email claim from the token.

    Returns:
    - None: if 'user_email_jwt_field' is not configured
    - the claim value: if the field is present in the token
    - default_value: if the field is configured but missing from the token
    """
    email_field = self.litellm_jwtauth.user_email_jwt_field
    if email_field is None:
        return None
    if email_field in token:
        return token[email_field]
    return default_value
|
||||
|
||||
def get_object_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Read the object id claim from the token.

    Returns default_value when 'object_id_jwt_field' is not configured or the
    token does not contain that field.
    """
    id_field = self.litellm_jwtauth.object_id_jwt_field
    if id_field is not None and id_field in token:
        return token[id_field]
    return default_value
|
||||
|
||||
def get_org_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
    """
    Read the organization id claim from the token.

    Returns:
    - None: if 'org_id_jwt_field' is not configured
    - the claim value: if the field is present in the token
    - default_value: if the field is configured but missing from the token
    """
    org_field = self.litellm_jwtauth.org_id_jwt_field
    return None if org_field is None else token.get(org_field, default_value)
|
||||
|
||||
def get_scopes(self, token: dict) -> List[str]:
    """
    Extract the scopes from the token's 'scope' claim.

    Returns:
    - []: if the token has no 'scope' claim
    - the whitespace-split parts: if the claim is a string
    - the claim itself: if it is already a list

    Raises:
        Exception: if the claim is neither a string nor a list
    """
    if "scope" not in token:
        return []

    scope_claim = token["scope"]
    if isinstance(scope_claim, str):
        # Assuming the scopes are stored in 'scope' claim and are space-separated
        return scope_claim.split()
    if isinstance(scope_claim, list):
        return scope_claim

    raise Exception(
        f"Unmapped scope type - {type(token['scope'])}. Supported types - list, str."
    )
|
||||
|
||||
async def get_public_key(self, kid: Optional[str]) -> dict:
    """
    Find the public key matching `kid` among the configured key endpoints.

    The endpoints are read from the JWT_PUBLIC_KEY_URL environment variable
    (comma-separated). For each endpoint the key set is taken from the cache
    or, on a cache miss, fetched over HTTP and cached for
    'public_key_ttl'. The first endpoint whose keys contain a match wins.

    Args:
        kid: key id from the token header, may be None

    Returns:
        the matching key as returned by `parse_keys`

    Raises:
        Exception: if JWT_PUBLIC_KEY_URL is not set, or no endpoint yields a
            matching key
    """
    keys_url = os.getenv("JWT_PUBLIC_KEY_URL")

    if keys_url is None:
        raise Exception("Missing JWT Public Key URL from environment.")

    # Skip blank entries (e.g. a trailing comma) instead of requesting "".
    keys_url_list = [url.strip() for url in keys_url.split(",") if url.strip()]

    for key_url in keys_url_list:
        cache_key = f"litellm_jwt_auth_keys_{key_url}"

        cached_keys = await self.user_api_key_cache.async_get_cache(cache_key)

        if cached_keys is None:
            response = await self.http_handler.get(key_url)

            # Parse the body once and reuse it for both the check and the value.
            response_json = response.json()
            if "keys" in response_json:
                keys: JWKKeyValue = response_json["keys"]
            else:
                keys = response_json

            await self.user_api_key_cache.async_set_cache(
                key=cache_key,
                value=keys,
                ttl=self.litellm_jwtauth.public_key_ttl,  # lifetime comes from config
            )
        else:
            keys = cached_keys

        public_key = self.parse_keys(keys=keys, kid=kid)
        if public_key is not None:
            return cast(dict, public_key)

    raise Exception(
        f"No matching public key found. keys={keys_url_list}, kid={kid}"
    )
|
||||
|
||||
def parse_keys(self, keys: JWKKeyValue, kid: Optional[str]) -> Optional[JWTKeyItem]:
    """
    Pick the key matching `kid` out of a fetched key document.

    Args:
        keys: either a single JWK (dict) or a list of JWKs
        kid: key id from the token header, may be None

    Returns:
    - a single JWK dict: if `kid` is None or equals its 'kid'
    - the only entry of a one-element list: under the same condition
    - for longer lists: the entry whose 'kid' equals `kid` (the last one if
      several match); None when `kid` is None
    - None: if `keys` is empty or nothing matches
    """
    if not keys:
        return None

    if isinstance(keys, dict):
        # A bare JWK is one key regardless of how many fields it carries.
        # Checking len(keys) here would count its fields (kty, kid, n, e, ...)
        # and wrongly send it down the multi-key path, where it never matches.
        if kid is None or keys.get("kid", None) == kid:
            return keys
        return None

    if len(keys) == 1:
        if isinstance(keys, list) and (
            kid is None or keys[0].get("kid", None) == kid
        ):
            return keys[0]
        return None

    public_key: Optional[JWTKeyItem] = None
    if kid is not None:
        for key in keys:
            if isinstance(key, dict) and key.get("kid", None) == kid:
                public_key = key

    return public_key
|
||||
|
||||
def is_allowed_domain(self, user_email: str) -> bool:
    """
    Check the email's domain against 'user_allowed_email_domain'.

    Returns:
    - True: if no allowed domain is configured, or the text after the last
      '@' equals the configured domain
    - False: otherwise
    """
    allowed_domain = self.litellm_jwtauth.user_allowed_email_domain
    if allowed_domain is None:
        return True

    # Everything after the final '@' (the whole string if there is none).
    _, _, email_domain = user_email.rpartition("@")
    return email_domain == allowed_domain
|
||||
|
||||
async def auth_jwt(self, token: str) -> dict:
    """
    Verify a JWT against the configured public keys and return its payload.

    The key is looked up via `get_public_key` using the 'kid' from the
    (unverified) token header. Audience verification is only performed when
    the JWT_AUDIENCE environment variable is set.

    Args:
        token: the encoded JWT

    Returns:
        the decoded token payload

    Raises:
        Exception: "Token Expired" for an expired signature,
            "Validation fails: ..." for any other decode error, and
            "Invalid JWT Submitted" if the key is neither a dict nor a str
    """
    # Supported algos: https://pyjwt.readthedocs.io/en/stable/algorithms.html
    # "Warning: Make sure not to mix symmetric and asymmetric algorithms that interpret
    # the key in different ways (e.g. HS* and RS*)."
    algorithms = ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]

    audience = os.getenv("JWT_AUDIENCE")
    decode_options = None
    if audience is None:
        decode_options = {"verify_aud": False}

    import jwt
    from jwt.algorithms import RSAAlgorithm

    header = jwt.get_unverified_header(token)

    verbose_proxy_logger.debug("header: %s", header)

    kid = header.get("kid", None)

    public_key = await self.get_public_key(kid=kid)

    if public_key is not None and isinstance(public_key, dict):
        # Forward only the JWK members needed to build the RSA key.
        jwk = {
            field: public_key[field]
            for field in ("kty", "kid", "n", "e")
            if field in public_key
        }

        public_key_rsa = RSAAlgorithm.from_jwk(json.dumps(jwk))

        try:
            # decode the token using the public key
            payload = jwt.decode(
                token,
                public_key_rsa,  # type: ignore
                algorithms=algorithms,
                options=decode_options,
                audience=audience,
                leeway=self.leeway,  # allow testing of expired tokens
            )
            return payload

        except jwt.ExpiredSignatureError as e:
            # the token is expired, do something to refresh it
            raise Exception("Token Expired") from e
        except Exception as e:
            raise Exception(f"Validation fails: {str(e)}") from e
    elif public_key is not None and isinstance(public_key, str):
        try:
            cert = x509.load_pem_x509_certificate(
                public_key.encode(), default_backend()
            )

            # Extract public key
            key = cert.public_key().public_bytes(
                serialization.Encoding.PEM,
                serialization.PublicFormat.SubjectPublicKeyInfo,
            )

            # decode the token using the public key
            payload = jwt.decode(
                token,
                key,
                algorithms=algorithms,
                audience=audience,
                options=decode_options,
                # Same expiry tolerance as the JWK branch; it was missing
                # here, so the configured leeway was ignored for certificates.
                leeway=self.leeway,
            )
            return payload

        except jwt.ExpiredSignatureError as e:
            # the token is expired, do something to refresh it
            raise Exception("Token Expired") from e
        except Exception as e:
            raise Exception(f"Validation fails: {str(e)}") from e

    raise Exception("Invalid JWT Submitted")
|
||||
|
||||
async def close(self):
    """Close the HTTP client owned by this handler."""
    client = self.http_handler
    await client.close()
|
||||
|
||||
|
||||
class JWTAuthManager:
    """Manages JWT authentication and authorization operations"""

    @staticmethod
    def can_rbac_role_call_route(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        route: str,
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the route, based on their role.

        Returns True when no role-based routes are configured for the role (or
        route is None) or when the route is allowed.

        Raises:
            HTTPException (403): if the route is not allowed for the role
        """
        role_based_routes = get_role_based_routes(
            rbac_role=rbac_role, general_settings=general_settings
        )

        if role_based_routes is None or route is None:
            return True

        is_allowed = _allowed_routes_check(
            user_route=route,
            allowed_routes=role_based_routes,
        )

        if not is_allowed:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call route={route}. Allowed routes={role_based_routes}",
            )

        return True

    @staticmethod
    def can_rbac_role_call_model(
        rbac_role: RBAC_ROLES,
        general_settings: dict,
        model: Optional[str],
    ) -> Literal[True]:
        """
        Checks if user is allowed to access the model, based on their role.

        Returns True when no role-based models are configured for the role, no
        model was requested, or the model is in the role's list.

        Raises:
            HTTPException (403): if the model is not allowed for the role
        """
        role_based_models = get_role_based_models(
            rbac_role=rbac_role, general_settings=general_settings
        )
        if role_based_models is None or model is None:
            return True

        if model not in role_based_models:
            raise HTTPException(
                status_code=403,
                detail=f"Role={rbac_role} not allowed to call model={model}. Allowed models={role_based_models}",
            )

        return True

    @staticmethod
    def check_scope_based_access(
        scope_mappings: List[ScopeMapping],
        scopes: List[str],
        request_data: dict,
        general_settings: dict,
    ) -> None:
        """
        Check if scope allows access to the requested model

        Nothing is checked when there are no scope mappings or the request
        names no model.

        Raises:
            HTTPException (403): if the requested model is not among the models
                of any mapping whose scope the token carries
        """
        if not scope_mappings:
            return None

        # Union of the models granted by every mapping the token has a scope for.
        allowed_models = []
        for sm in scope_mappings:
            if sm.scope in scopes and sm.models:
                allowed_models.extend(sm.models)

        requested_model = request_data.get("model")

        if not requested_model:
            return None

        if requested_model not in allowed_models:
            raise HTTPException(
                status_code=403,
                detail={
                    "error": "model={} not allowed. Allowed_models={}".format(
                        requested_model, allowed_models
                    )
                },
            )
        return None

    @staticmethod
    async def check_rbac_role(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        general_settings: dict,
        request_data: dict,
        route: str,
        rbac_role: Optional[RBAC_ROLES],
    ) -> None:
        """
        Validate RBAC role and model access permissions

        Only runs when 'enforce_rbac' is True on the JWT auth config.

        Raises:
            HTTPException (403): if no role could be determined, or the role
                may not call the requested model or route
        """
        if jwt_handler.litellm_jwtauth.enforce_rbac is True:
            if rbac_role is None:
                raise HTTPException(
                    status_code=403,
                    detail="Unmatched token passed in. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
                )
            JWTAuthManager.can_rbac_role_call_model(
                rbac_role=rbac_role,
                general_settings=general_settings,
                model=request_data.get("model"),
            )
            JWTAuthManager.can_rbac_role_call_route(
                rbac_role=rbac_role,
                general_settings=general_settings,
                route=route,
            )

    @staticmethod
    async def check_admin_access(
        jwt_handler: JWTHandler,
        scopes: list,
        route: str,
        user_id: Optional[str],
        org_id: Optional[str],
        api_key: str,
    ) -> Optional[JWTAuthBuilderResult]:
        """
        Check admin status and route access permissions

        Returns:
        - None: if the scopes do not mark the token as admin
        - a proxy-admin JWTAuthBuilderResult: if the admin may call the route

        Raises:
            Exception: if the token is admin but the route is not allowed
        """
        if not jwt_handler.is_admin(scopes=scopes):
            return None

        is_allowed = allowed_routes_check(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            user_route=route,
            litellm_proxy_roles=jwt_handler.litellm_jwtauth,
        )
        if not is_allowed:
            allowed_routes: List[Any] = jwt_handler.litellm_jwtauth.admin_allowed_routes
            actual_routes = get_actual_routes(allowed_routes=allowed_routes)
            raise Exception(
                f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
            )

        return JWTAuthBuilderResult(
            is_proxy_admin=True,
            team_object=None,
            user_object=None,
            end_user_object=None,
            org_object=None,
            token=api_key,
            team_id=None,
            user_id=user_id,
            end_user_id=None,
            org_id=org_id,
        )

    @staticmethod
    async def find_and_validate_specific_team_id(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
        """
        Find and validate specific team ID

        Returns:
            (team id, team object); the team object is only looked up when a
            team id was found

        Raises:
            Exception: if no team id is found although a team id field is
                configured
        """
        individual_team_id = jwt_handler.get_team_id(
            token=jwt_valid_token, default_value=None
        )

        if not individual_team_id and jwt_handler.is_required_team_id() is True:
            raise Exception(
                f"No team id found in token. Checked team_id field '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
            )

        ## VALIDATE TEAM OBJECT ###
        team_object: Optional[LiteLLM_TeamTable] = None
        if individual_team_id:
            team_object = await get_team_object(
                team_id=individual_team_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
                team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
            )

        return individual_team_id, team_object

    @staticmethod
    def get_all_team_ids(jwt_handler: JWTHandler, jwt_valid_token: dict) -> Set[str]:
        """Return the de-duplicated team ids listed in the token's team ids claim"""
        team_ids_from_groups = jwt_handler.get_team_ids_from_jwt(token=jwt_valid_token)

        all_team_ids = set(team_ids_from_groups)

        return all_team_ids

    @staticmethod
    async def find_team_with_model_access(
        team_ids: Set[str],
        requested_model: Optional[str],
        route: str,
        jwt_handler: JWTHandler,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[Optional[str], Optional[LiteLLM_TeamTable]]:
        """
        Find first team with access to the requested model

        A team qualifies when it has a models list, can access the requested
        model (any team qualifies if no model was requested) and the TEAM role
        may call the route. Teams whose lookup or check raises are skipped.

        Returns:
            (team id, team object) of the first qualifying team, or
            (None, None) if there is none and no model was requested

        Raises:
            HTTPException (403): if there are no team ids while
                'enforce_team_based_model_access' is set, or a model was
                requested and no team qualifies
        """
        if not team_ids:
            if jwt_handler.litellm_jwtauth.enforce_team_based_model_access:
                raise HTTPException(
                    status_code=403,
                    detail="No teams found in token. `enforce_team_based_model_access` is set to True. Token must belong to a team.",
                )
            return None, None

        for team_id in team_ids:
            try:
                team_object = await get_team_object(
                    team_id=team_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    parent_otel_span=parent_otel_span,
                    proxy_logging_obj=proxy_logging_obj,
                )

                if team_object and team_object.models is not None:
                    team_models = team_object.models
                    if isinstance(team_models, list) and (
                        not requested_model
                        or can_team_access_model(
                            model=requested_model,
                            team_object=team_object,
                            llm_router=None,
                            team_model_aliases=None,
                        )
                    ):
                        is_allowed = allowed_routes_check(
                            user_role=LitellmUserRoles.TEAM,
                            user_route=route,
                            litellm_proxy_roles=jwt_handler.litellm_jwtauth,
                        )
                        if is_allowed:
                            return team_id, team_object
            except Exception as e:
                # Best effort: one failing team must not block the others, but
                # leave a trace so the skipped team can be diagnosed.
                verbose_proxy_logger.debug(
                    "Skipping team_id=%s while searching for model access: %s",
                    team_id,
                    str(e),
                )
                continue

        if requested_model:
            raise HTTPException(
                status_code=403,
                detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.",
            )

        return None, None

    @staticmethod
    async def get_user_info(
        jwt_handler: JWTHandler,
        jwt_valid_token: dict,
    ) -> Tuple[Optional[str], Optional[str], Optional[bool]]:
        """
        Get user email and validation status

        Returns:
            (user_id, user_email, valid_user_email) where
            - user_id falls back to the email when no user id claim is found
            - valid_user_email is None when no email domain is enforced,
              False when the email is missing, otherwise the domain check
        """
        user_email = jwt_handler.get_user_email(
            token=jwt_valid_token, default_value=None
        )
        valid_user_email = None
        if jwt_handler.is_enforced_email_domain():
            valid_user_email = (
                False
                if user_email is None
                else jwt_handler.is_allowed_domain(user_email=user_email)
            )
        user_id = jwt_handler.get_user_id(
            token=jwt_valid_token, default_value=user_email
        )
        return user_id, user_email, valid_user_email

    @staticmethod
    async def get_objects(
        user_id: Optional[str],
        user_email: Optional[str],
        org_id: Optional[str],
        end_user_id: Optional[str],
        valid_user_email: Optional[bool],
        jwt_handler: JWTHandler,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> Tuple[
        Optional[LiteLLM_UserTable],
        Optional[LiteLLM_OrganizationTable],
        Optional[LiteLLM_EndUserTable],
    ]:
        """
        Get user, org, and end user objects

        Each object is only looked up when its id is truthy; otherwise the
        corresponding slot of the result is None.

        Returns:
            (user_object, org_object, end_user_object)
        """
        org_object: Optional[LiteLLM_OrganizationTable] = None
        if org_id:
            org_object = await get_org_object(
                org_id=org_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        user_object: Optional[LiteLLM_UserTable] = None
        if user_id:
            user_object = await get_user_object(
                user_id=user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                user_id_upsert=jwt_handler.is_upsert_user_id(
                    valid_user_email=valid_user_email
                ),
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
                user_email=user_email,
                sso_user_id=user_id,
            )

        end_user_object: Optional[LiteLLM_EndUserTable] = None
        if end_user_id:
            end_user_object = await get_end_user_object(
                end_user_id=end_user_id,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        return user_object, org_object, end_user_object

    @staticmethod
    def validate_object_id(
        user_id: Optional[str],
        team_id: Optional[str],
        enforce_rbac: bool,
        is_proxy_admin: bool,
    ) -> Literal[True]:
        """
        If enforce_rbac is true, validate that a valid rbac id is returned for spend tracking

        Raises:
            HTTPException (403): if enforce_rbac is set, the caller is not a
                proxy admin and neither a user id nor a team id is available
        """
        if enforce_rbac and not is_proxy_admin and not user_id and not team_id:
            raise HTTPException(
                status_code=403,
                detail="No user or team id found in token. enforce_rbac is set to True. Token must belong to a proxy admin, team, or user.",
            )
        return True

    @staticmethod
    async def auth_builder(
        api_key: str,
        jwt_handler: JWTHandler,
        request_data: dict,
        general_settings: dict,
        route: str,
        prisma_client: Optional[PrismaClient],
        user_api_key_cache: DualCache,
        parent_otel_span: Optional[Span],
        proxy_logging_obj: ProxyLogging,
    ) -> JWTAuthBuilderResult:
        """
        Main authentication and authorization builder

        Steps: verify the token, run the optional custom validator, the RBAC
        and scope checks, short-circuit for admins, resolve the team (specific
        team id first, then the token's team ids), and load the user, org and
        end-user objects.

        Raises:
            HTTPException (403): if the custom validator rejects the token, or
                one of the access checks fails
        """
        jwt_valid_token: dict = await jwt_handler.auth_jwt(token=api_key)

        # Check custom validate
        if jwt_handler.litellm_jwtauth.custom_validate:
            if not jwt_handler.litellm_jwtauth.custom_validate(jwt_valid_token):
                raise HTTPException(
                    status_code=403,
                    detail="Invalid JWT token",
                )

        # Check RBAC
        rbac_role = jwt_handler.get_rbac_role(token=jwt_valid_token)
        await JWTAuthManager.check_rbac_role(
            jwt_handler,
            jwt_valid_token,
            general_settings,
            request_data,
            route,
            rbac_role,
        )

        # Check Scope Based Access
        scopes = jwt_handler.get_scopes(token=jwt_valid_token)
        if (
            jwt_handler.litellm_jwtauth.enforce_scope_based_access
            and jwt_handler.litellm_jwtauth.scope_mappings
        ):
            JWTAuthManager.check_scope_based_access(
                scope_mappings=jwt_handler.litellm_jwtauth.scope_mappings,
                scopes=scopes,
                request_data=request_data,
                general_settings=general_settings,
            )

        # Get basic user info
        user_id, user_email, valid_user_email = await JWTAuthManager.get_user_info(
            jwt_handler, jwt_valid_token
        )

        # Get IDs
        org_id = jwt_handler.get_org_id(token=jwt_valid_token, default_value=None)
        end_user_id = jwt_handler.get_end_user_id(
            token=jwt_valid_token, default_value=None
        )
        team_id: Optional[str] = None
        team_object: Optional[LiteLLM_TeamTable] = None
        object_id = jwt_handler.get_object_id(token=jwt_valid_token, default_value=None)

        # The object id stands in for the team or user id, depending on the role.
        if rbac_role and object_id:
            if rbac_role == LitellmUserRoles.TEAM:
                team_id = object_id
            elif rbac_role == LitellmUserRoles.INTERNAL_USER:
                user_id = object_id

        # Check admin access
        admin_result = await JWTAuthManager.check_admin_access(
            jwt_handler, scopes, route, user_id, org_id, api_key
        )
        if admin_result:
            return admin_result

        # Get team with model access
        ## SPECIFIC TEAM ID
        if not team_id:
            (
                team_id,
                team_object,
            ) = await JWTAuthManager.find_and_validate_specific_team_id(
                jwt_handler,
                jwt_valid_token,
                prisma_client,
                user_api_key_cache,
                parent_otel_span,
                proxy_logging_obj,
            )

        if not team_object and not team_id:
            ## CHECK USER GROUP ACCESS
            all_team_ids = JWTAuthManager.get_all_team_ids(jwt_handler, jwt_valid_token)
            team_id, team_object = await JWTAuthManager.find_team_with_model_access(
                team_ids=all_team_ids,
                requested_model=request_data.get("model"),
                route=route,
                jwt_handler=jwt_handler,
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
                parent_otel_span=parent_otel_span,
                proxy_logging_obj=proxy_logging_obj,
            )

        # Get other objects
        user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
            user_id=user_id,
            user_email=user_email,
            org_id=org_id,
            end_user_id=end_user_id,
            valid_user_email=valid_user_email,
            jwt_handler=jwt_handler,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            parent_otel_span=parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )

        # Validate that a valid rbac id is returned for spend tracking
        # NOTE(review): this reads 'enforce_rbac' from general_settings while
        # check_rbac_role reads it from jwt_handler.litellm_jwtauth - confirm
        # that the two sources are meant to differ.
        JWTAuthManager.validate_object_id(
            user_id=user_id,
            team_id=team_id,
            enforce_rbac=general_settings.get("enforce_rbac", False),
            is_proxy_admin=False,
        )

        return JWTAuthBuilderResult(
            is_proxy_admin=False,
            team_id=team_id,
            team_object=team_object,
            user_id=user_id,
            user_object=user_object,
            org_id=org_id,
            org_object=org_object,
            end_user_id=end_user_id,
            end_user_object=end_user_object,
            token=api_key,
        )
|
||||
@@ -0,0 +1,169 @@
|
||||
# What is this?
|
||||
## If litellm license in env, checks if it's valid
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import NON_LLM_CONNECTION_TIMEOUT
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
|
||||
class LicenseCheck:
|
||||
"""
|
||||
- Check if license in env
|
||||
- Returns if license is valid
|
||||
"""
|
||||
|
||||
base_url = "https://license.litellm.ai"
|
||||
|
||||
def __init__(self) -> None:
    """
    Read the license from the LITELLM_LICENSE environment variable, create
    the HTTP client and try to load the bundled public key.
    """
    license_from_env = os.getenv("LITELLM_LICENSE", None)
    self.license_str = license_from_env
    verbose_proxy_logger.debug("License Str value - {}".format(self.license_str))

    self.http_handler = HTTPHandler(timeout=NON_LLM_CONNECTION_TIMEOUT)

    # Stays None unless read_public_key() finds and loads a key file.
    self.public_key = None
    self.read_public_key()
|
||||
|
||||
def read_public_key(self):
    """
    Load `public_key.pem` from this module's directory into self.public_key.

    self.public_key is set to None when the file does not exist. Any error
    (including a missing `cryptography` package) is logged and not raised.
    """
    try:
        from cryptography.hazmat.primitives import serialization

        # the key file is expected next to this module
        module_dir = os.path.dirname(os.path.realpath(__file__))
        pem_path = os.path.join(module_dir, "public_key.pem")

        if not os.path.exists(pem_path):
            self.public_key = None
            return

        with open(pem_path, "rb") as pem_file:
            pem_bytes = pem_file.read()
        self.public_key = serialization.load_pem_public_key(pem_bytes)
    except Exception as e:
        verbose_proxy_logger.error(f"Error reading public key: {str(e)}")
|
||||
|
||||
def _verify(self, license_str: str) -> bool:
    """
    Ask the license server whether `license_str` is a valid premium license.

    The request is retried up to three times when the server answers with an
    HTTP error status. Failures never propagate to the caller.

    Args:
        license_str: the license key to verify

    Returns:
        the boolean 'verify' field of the server response, or False if the
        request fails or the response is not usable
    """
    verbose_proxy_logger.debug(
        "litellm.proxy.auth.litellm_license.py::_verify - Checking license against {}/verify_license - {}".format(
            self.base_url, license_str
        )
    )
    url = "{}/verify_license/{}".format(self.base_url, license_str)

    response: Optional[httpx.Response] = None
    try:  # don't impact user, if call fails
        num_retries = 3
        for i in range(num_retries):
            try:
                response = self.http_handler.get(url=url)
                if response is None:
                    raise Exception("No response from license server")
                response.raise_for_status()
                # Success: stop here. Without this the loop kept going and
                # every verification hit the license server three times.
                break
            except httpx.HTTPStatusError:
                if i == num_retries - 1:
                    raise

        if response is None:
            raise Exception("No response from license server")

        response_json = response.json()

        premium = response_json["verify"]

        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(premium, bool):
            raise TypeError(
                "Expected boolean 'verify' in license response, got {}".format(
                    type(premium)
                )
            )

        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::_verify - License={} is premium={}".format(
                license_str, premium
            )
        )
        return premium
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.auth.litellm_license.py::_verify - Unable to verify License={} via api. - {}".format(
                license_str, str(e)
            )
        )
        return False
|
||||
|
||||
def is_premium(self) -> bool:
    """
    Report whether a valid premium license is configured.

    The key comes from ``self.license_str``; when that is None it is read
    from the LITELLM_LICENSE environment variable and stored back on
    ``self.license_str``. The key is then checked in this order:

    1. verify_license_without_api_request: checks if the license was
       generated using the private / public key pair.
    2. _verify: checks if the license is valid by calling the litellm API
       (the older way of generating/validating licenses).

    Returns:
        True if either check returns True. False when no key is available,
        both checks fail, or any exception is raised along the way.
    """
    try:
        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::is_premium() - ENTERING 'IS_PREMIUM' - LiteLLM License={}".format(
                self.license_str
            )
        )

        if self.license_str is None:
            self.license_str = os.getenv("LITELLM_LICENSE", None)

        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::is_premium() - Updated 'self.license_str' - {}".format(
                self.license_str
            )
        )

        license_key = self.license_str
        if license_key is None:
            return False

        # Offline signature check first; it needs no network round trip.
        verified_offline = self.verify_license_without_api_request(
            public_key=self.public_key, license_key=license_key
        )
        if verified_offline is True:
            return True

        # Fall back to the remote license server.
        return self._verify(license_str=license_key) is True
    except Exception:
        return False
|
||||
|
||||
def verify_license_without_api_request(self, public_key, license_key) -> bool:
    """
    Verify a license key locally, without calling the license server.

    The key is base64-decoded and split at the first ``b"."`` into a message
    and a signature. The signature is checked with ``public_key.verify``
    using PSS padding (MGF1 over SHA-256, maximum salt length) and SHA-256.
    The message is then parsed as JSON and its ``expiration_date`` field
    (format ``%Y-%m-%d``) is compared against the current local time.

    Args:
        public_key: Key object exposing ``verify(signature, message, padding,
            algorithm)``; a None value ends up in the failure path.
        license_key: The base64-encoded license string.

    Returns:
        True only when the signature is valid and the license has not
        expired. False in every other case, including any exception (bad
        encoding, invalid signature, missing key, malformed payload).
    """
    try:
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import padding

        # Decode the license key
        decoded = base64.b64decode(license_key)
        message, signature = decoded.split(b".", 1)

        # Verify the signature (raises on mismatch)
        public_key.verify(
            signature,
            message,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH,
            ),
            hashes.SHA256(),
        )

        # Decode and parse the data
        license_data = json.loads(message.decode())

        # debug information provided in license data
        verbose_proxy_logger.debug("License data: %s", license_data)

        # Check expiration date
        expiration_date = datetime.strptime(
            license_data["expiration_date"], "%Y-%m-%d"
        )
        if expiration_date < datetime.now():
            # Previously returned the tuple (False, "License has expired"),
            # which is truthy; keep the return type a plain bool so a
            # truthiness check can never treat an expired license as valid.
            verbose_proxy_logger.debug("License has expired")
            return False

        return True

    except Exception as e:
        verbose_proxy_logger.debug(
            "litellm.proxy.auth.litellm_license.py::verify_license_without_api_request - Unable to verify License locally. - {}".format(
                str(e)
            )
        )
        return False
|
||||
@@ -0,0 +1,224 @@
|
||||
# What is this?
|
||||
## Common checks for /v1/models and `/model/info`
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
|
||||
from litellm.router import Router
|
||||
from litellm.types.router import LiteLLM_Params
|
||||
from litellm.utils import get_valid_models
|
||||
|
||||
|
||||
def _check_wildcard_routing(model: str) -> bool:
|
||||
"""
|
||||
Returns True if a model is a provider wildcard.
|
||||
|
||||
eg:
|
||||
- anthropic/*
|
||||
- openai/*
|
||||
- *
|
||||
"""
|
||||
if "*" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Look up the known models for a provider.

    Args:
        provider: Provider name, or ``"*"`` for every provider.
        litellm_params: Optional params forwarded to ``get_valid_models``.

    Returns:
        - For ``"*"``: whatever ``get_valid_models`` returns, unchanged.
        - For a provider listed in ``litellm.models_by_provider``: its valid
          models, with ``"<provider>/"`` prepended to every name that does
          not already contain the provider string.
        - None for an unknown provider.
    """
    if provider == "*":
        return get_valid_models(litellm_params=litellm_params)

    if provider not in litellm.models_by_provider:
        return None

    known_models = get_valid_models(
        custom_llm_provider=provider, litellm_params=litellm_params
    )
    # Rewrite entries in place so the returned object is the same list that
    # get_valid_models produced.
    for position, model_name in enumerate(known_models):
        if provider not in model_name:
            known_models[position] = f"{provider}/{model_name}"
    return known_models
|
||||
|
||||
|
||||
def _get_models_from_access_groups(
|
||||
model_access_groups: Dict[str, List[str]],
|
||||
all_models: List[str],
|
||||
) -> List[str]:
|
||||
idx_to_remove = []
|
||||
new_models = []
|
||||
for idx, model in enumerate(all_models):
|
||||
if model in model_access_groups:
|
||||
idx_to_remove.append(idx)
|
||||
new_models.extend(model_access_groups[model])
|
||||
|
||||
for idx in sorted(idx_to_remove, reverse=True):
|
||||
all_models.pop(idx)
|
||||
|
||||
all_models.extend(new_models)
|
||||
return all_models
|
||||
|
||||
|
||||
def get_key_models(
    user_api_key_dict: UserAPIKeyAuth,
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
) -> List[str]:
    """
    Resolve the model names a key may use.

    Starts from the key's own ``models`` (empty list when none are set),
    switches to the key's ``team_models`` if the "all team models" marker is
    present, switches to ``proxy_model_list`` if the "all proxy models"
    marker is present, and finally expands any access-group names via
    ``_get_models_from_access_groups``.

    Returns:
        - List of model name strings
        - Empty list if no models set
    """
    resolved: List[str] = []
    if len(user_api_key_dict.models) > 0:
        resolved = user_api_key_dict.models

    # The special markers can only be present when the key has models set.
    if SpecialModelNames.all_team_models.value in resolved:
        resolved = user_api_key_dict.team_models
    if SpecialModelNames.all_proxy_models.value in resolved:
        resolved = proxy_model_list

    resolved = _get_models_from_access_groups(
        model_access_groups=model_access_groups, all_models=resolved
    )

    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(resolved)))
    return resolved
|
||||
|
||||
|
||||
def get_team_models(
    team_models: List[str],
    proxy_model_list: List[str],
    model_access_groups: Dict[str, List[str]],
) -> List[str]:
    """
    Resolve the model names a team may use.

    Starts from ``team_models`` (empty list when none are set), switches to
    ``proxy_model_list`` if the "all proxy models" marker is present, and
    finally expands any access-group names via
    ``_get_models_from_access_groups``.

    Returns:
        - List of model name strings
        - Empty list if no models set
    """
    resolved = team_models if len(team_models) > 0 else []

    # The "all team models" marker needs no handling here: it would map the
    # list onto team_models, which is what it already is.
    if SpecialModelNames.all_proxy_models.value in resolved:
        resolved = proxy_model_list

    resolved = _get_models_from_access_groups(
        model_access_groups=model_access_groups, all_models=resolved
    )

    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(resolved)))
    return resolved
|
||||
|
||||
|
||||
def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Build the complete model list for a given key + team pair.

    - If the key list is empty -> defer to the team list
    - If the team list is empty -> defer to the proxy model list
    - ``user_model`` is added when set
    - When ``infer_model_from_keys`` is truthy, the models reported by
      ``get_valid_models()`` are added as well
    - Wildcard entries are resolved through ``_get_wildcard_models`` and the
      resulting names are appended after the de-duplicated base models
    """
    if key_models:
        base_models = key_models
    elif team_models:
        base_models = team_models
    else:
        base_models = proxy_model_list

    unique_models: Set[str] = set(base_models)

    if user_model:
        unique_models.add(user_model)

    if infer_model_from_keys:
        unique_models.update(get_valid_models())

    # Must run before list(unique_models): the helper may remove wildcard
    # entries from the set it is given.
    wildcard_expansions = _get_wildcard_models(
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )

    return list(unique_models) + wildcard_expansions
|
||||
|
||||
|
||||
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    """
    Expand a ``provider/pattern`` wildcard into the known matching models.

    Args:
        wildcard_model: Wildcard such as ``anthropic/*`` or ``openai/gpt-4*``.
        litellm_params: Optional params forwarded to ``get_provider_models``.

    Returns:
        - [] when ``wildcard_model`` has no ``/`` or the provider is unknown
        - every provider model when the pattern is exactly ``*``
        - otherwise the provider models whose name starts with the pattern
          with all ``*`` characters removed
    """
    try:
        provider, model = wildcard_model.split("/", 1)
    except ValueError:  # safely fail
        return []

    # get all known provider models
    wildcard_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if wildcard_models is None:
        return []
    if model == "*":
        return wildcard_models or []

    model_prefix = model.replace("*", "")
    filtered_wildcard_models: List[str] = []
    for wc_model in wildcard_models:
        segments = wc_model.split("/")
        # Compare against the segment after the provider prefix. Names with
        # no "/" used to raise IndexError on split("/")[1]; for those the
        # whole name is compared.
        name = segments[1] if len(segments) > 1 else segments[0]
        if name.startswith(model_prefix):
            filtered_wildcard_models.append(wc_model)

    return filtered_wildcard_models
|
||||
|
||||
|
||||
def _get_wildcard_models(
    unique_models: Set[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """
    Expand every wildcard entry of ``unique_models`` into known model names.

    For each wildcard entry:
    - if ``return_wildcard_routes`` is truthy, the wildcard route itself
      (eg: anthropic/*) is included in the result;
    - with a router, the wildcard is expanded once per deployment the router
      returns for that name, using that deployment's ``litellm_params``;
    - without a router, the wildcard is expanded from the known provider
      models and then removed from ``unique_models``.

    Note that ``unique_models`` is modified in place (no-router path only).

    Returns:
        The collected wildcard routes / expanded model names.
    """
    expanded_wildcards = set()
    collected: List[str] = []

    for model in unique_models:
        if not _check_wildcard_routing(model=model):
            continue

        if return_wildcard_routes:
            # will add the wildcard route to the list eg: anthropic/*.
            collected.append(model)

        if llm_router is None:
            # get all known provider models
            known_models = get_known_models_from_wildcard(wildcard_model=model)
            if known_models is not None:
                expanded_wildcards.add(model)
                collected.extend(known_models)
            continue

        ## get litellm params from model
        deployments = llm_router.get_model_list(model_name=model)
        if deployments is not None:
            for deployment in deployments:
                collected.extend(
                    get_known_models_from_wildcard(
                        wildcard_model=model,
                        litellm_params=LiteLLM_Params(
                            **deployment["litellm_params"]  # type: ignore
                        ),
                    )
                )

    # Removal is deferred so the set is not modified while being iterated.
    for model in expanded_wildcards:
        unique_models.remove(model)

    return collected
|
||||
@@ -0,0 +1,80 @@
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
    """
    Makes a request to the token info endpoint to validate the OAuth2 token.

    The endpoint URL is read from the OAUTH_TOKEN_INFO_ENDPOINT environment
    variable and is called with the token as a bearer credential. Fields of
    the JSON response are mapped onto the returned object; the field names
    are configurable through OAUTH_USER_ID_FIELD_NAME (default "sub"),
    OAUTH_USER_ROLE_FIELD_NAME (default "role") and
    OAUTH_USER_TEAM_ID_FIELD_NAME (default "team_id").

    Args:
        token (str): The OAuth2 token to validate.

    Returns:
        UserAPIKeyAuth: Auth object carrying the token as ``api_key`` plus
        the user id, team id and role read from the endpoint's response.

    Raises:
        ValueError: If the proxy is not running as a premium user, the token
        info endpoint is not set, the endpoint answers with an error status,
        or the request fails for any other reason.
    """
    import os

    import httpx

    from litellm._logging import verbose_proxy_logger
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            "Oauth2 token validation is only available for premium users"
            + CommonProxyErrors.not_premium_user.value
        )

    # The raw bearer token is a credential and is deliberately kept out of
    # the logs.
    verbose_proxy_logger.debug("Oauth2 token validation requested")
    # Get the token info endpoint from environment variable
    token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
    user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
    user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
    user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")

    if not token_info_endpoint:
        raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")

    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    try:
        response = await client.get(token_info_endpoint, headers=headers)

        # if it's a bad token we expect it to raise an HTTPStatusError
        response.raise_for_status()

        # If we get here, the request was successful
        data = response.json()

        verbose_proxy_logger.debug(
            "Oauth2 token validation, response from /token/info=%s",
            data,
        )

        # You might want to add additional checks here based on the response
        # For example, checking if the token is expired or has the correct scope
        user_id = data.get(user_id_field_name)
        user_team_id = data.get(user_team_id_field_name)
        user_role = data.get(user_role_field_name)

        return UserAPIKeyAuth(
            api_key=token,
            team_id=user_team_id,
            user_id=user_id,
            user_role=user_role,
        )
    except httpx.HTTPStatusError as e:
        # This will catch any 4xx or 5xx errors
        raise ValueError(f"Oauth 2.0 Token validation failed: {e}") from e
    except Exception as e:
        # This will catch any other errors (like network issues)
        raise ValueError(f"An error occurred during token validation: {e}") from e
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
async def handle_oauth2_proxy_request(request: Request) -> UserAPIKeyAuth:
    """
    Build a UserAPIKeyAuth from the headers of an oauth2 proxy request.

    ``general_settings["oauth2_config_mappings"]`` maps UserAPIKeyAuth field
    names to request header names. Every mapped header that is present and
    non-empty is copied into the auth object; ``max_budget`` is converted to
    float and ``models`` is split on commas into a list of stripped names.

    Raises:
        ValueError: If no oauth2 config mappings are configured.
    """
    from litellm.proxy.proxy_server import general_settings

    verbose_proxy_logger.debug("Handling oauth2 proxy request")
    # Define the OAuth2 config mappings
    oauth2_config_mappings: Dict[str, str] = general_settings.get(
        "oauth2_config_mappings", None
    )
    verbose_proxy_logger.debug(f"Oauth2 config mappings: {oauth2_config_mappings}")

    if not oauth2_config_mappings:
        raise ValueError("Oauth2 config mappings not found in general_settings")

    # Collect the mapped values, keyed by UserAPIKeyAuth field name
    auth_data: Dict[str, Any] = {}
    for field_name, header_name in oauth2_config_mappings.items():
        header_value = request.headers.get(header_name)
        if not header_value:
            # header absent or empty -> leave the field unset
            continue

        if field_name == "max_budget":
            auth_data[field_name] = float(header_value)
        elif field_name == "models":
            auth_data[field_name] = [
                model.strip() for model in header_value.split(",")
            ]
        else:
            auth_data[field_name] = header_value

    verbose_proxy_logger.debug(
        f"Auth data before creating UserAPIKeyAuth object: {auth_data}"
    )
    user_api_key_auth = UserAPIKeyAuth(**auth_data)
    verbose_proxy_logger.debug(f"UserAPIKeyAuth object created: {user_api_key_auth}")
    return user_api_key_auth
|
||||
@@ -0,0 +1,9 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwcNBabWBZzrDhFAuA4Fh
|
||||
FhIcA3rF7vrLb8+1yhF2U62AghQp9nStyuJRjxMUuldWgJ1yRJ2s7UffVw5r8DeA
|
||||
dqXPD+w+3LCNwqJGaIKN08QGJXNArM3QtMaN0RTzAyQ4iibN1r6609W5muK9wGp0
|
||||
b1j5+iDUmf0ynItnhvaX6B8Xoaflc3WD/UBdrygLmsU5uR3XC86+/8ILoSZH3HtN
|
||||
6FJmWhlhjS2TR1cKZv8K5D0WuADTFf5MF8jYFR+uORPj5Pe/EJlLGN26Lfn2QnGu
|
||||
XgbPF6nCGwZ0hwH1Xkn3xzGaJ4xBEC761wqp5cHxWSDktHyFKnLbP3jVeegjVIHh
|
||||
pQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -0,0 +1,187 @@
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def init_rds_client(
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Create a boto3 ``rds`` client from the given AWS settings.

    Any string argument starting with ``os.environ/`` is first resolved
    through ``get_secret``. The region is ``aws_region_name`` if set,
    otherwise the AWS_REGION_NAME secret, otherwise the AWS_REGION secret.

    Credentials are chosen by the first matching case:
    1. web identity token + role + session name -> STS
       ``assume_role_with_web_identity`` (token read from file if the value
       is a readable path, else fetched via ``get_secret``)
    2. role + session name -> STS ``assume_role``
    3. access key id -> the explicit key pair
    4. profile name -> ``boto3.Session(profile_name=...)``
    5. otherwise -> boto3's default credential resolution

    Args:
        timeout: A number of seconds applied to both connect and read, or an
            ``httpx.Timeout`` whose ``connect`` / ``read`` values are used.
            Anything else leaves the default boto3 config.

    Returns:
        The boto3 RDS client.

    Raises:
        Exception: If no region can be determined, or the OIDC token cannot
        be retrieved.
    """
    from litellm.secret_managers.main import get_secret

    # check for custom AWS_REGION_NAME and use it if not passed to init_rds_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Resolve "os.environ/<NAME>" references to their secret values
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore

    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    ### SET REGION NAME
    region_name = (
        aws_region_name or litellm_aws_region_name or standard_aws_region_name
    )
    if not region_name:
        raise Exception(
            "AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
        )

    import boto3

    # Accept ints as well as floats: an int timeout used to fail the
    # isinstance(float) check and was silently ignored.
    if isinstance(timeout, (int, float)) and not isinstance(timeout, bool):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        try:
            # check if filepath; the context manager closes the handle that
            # the previous bare open(...).read() leaked
            with open(aws_web_identity_token) as token_file:
                oidc_token = token_file.read()
        except Exception:
            # not a readable file -> treat the value as a secret reference
            oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise Exception(
                "OIDC token could not be retrieved from secret manager.",
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )

    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="rds",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            config=config,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            service_name="rds",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            config=config,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(
            service_name="rds",
            region_name=region_name,
            config=config,
        )

    return client
|
||||
|
||||
|
||||
def generate_iam_auth_token(
    db_host, db_port, db_user, client: Optional[Any] = None
) -> str:
    """
    Produce a URL-safe IAM auth token for a database connection.

    Args:
        db_host: Database hostname passed as ``DBHostname``.
        db_port: Database port passed as ``Port``.
        db_user: Database user passed as ``DBUsername``.
        client: Object exposing ``generate_db_auth_token``. When None, a
            client is built with ``init_rds_client`` from the AWS_*
            environment variables.

    Returns:
        The token returned by ``generate_db_auth_token``, percent-encoded
        with no characters treated as safe.
    """
    from urllib.parse import quote

    boto_client = client
    if boto_client is None:
        boto_client = init_rds_client(
            aws_region_name=os.getenv("AWS_REGION_NAME"),
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_name=os.getenv("AWS_SESSION_NAME"),
            aws_profile_name=os.getenv("AWS_PROFILE_NAME"),
            aws_role_name=os.getenv("AWS_ROLE_NAME", os.getenv("AWS_ROLE_ARN")),
            aws_web_identity_token=os.getenv(
                "AWS_WEB_IDENTITY_TOKEN", os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE")
            ),
        )

    raw_token = boto_client.generate_db_auth_token(
        DBHostname=db_host, Port=db_port, DBUsername=db_user
    )
    # safe="" so that "/" is escaped as well
    return quote(raw_token, safe="")
|
||||
@@ -0,0 +1,313 @@
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
CommonProxyErrors,
|
||||
LiteLLM_UserTable,
|
||||
LiteLLMRoutes,
|
||||
LitellmUserRoles,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
|
||||
from .auth_checks_organization import _user_is_org_admin
|
||||
|
||||
|
||||
class RouteChecks:
    """Static helpers deciding whether a key or user may call a proxy route."""

    @staticmethod
    def is_virtual_key_allowed_to_call_route(
        route: str, valid_token: UserAPIKeyAuth
    ) -> bool:
        """
        Check a route against the key's ``allowed_routes``.

        Returns True when the key has no usable restriction (``allowed_routes``
        is None, not a list, or empty), when the route is listed explicitly,
        or when it matches one of the listed wildcard patterns.

        Raises:
            Exception: If the Virtual Key is not allowed to call the route.
        """
        allowed_routes = valid_token.allowed_routes

        # Only enforce when allowed_routes is a list with at least one item
        if (
            allowed_routes is None
            or not isinstance(allowed_routes, list)
            or len(allowed_routes) == 0
        ):
            return True

        # explicit check for allowed routes
        if route in allowed_routes:
            return True

        # check if wildcard pattern is allowed
        if any(
            RouteChecks._route_matches_wildcard_pattern(
                route=route, pattern=allowed_route
            )
            for allowed_route in allowed_routes
        ):
            return True

        raise Exception(
            f"Virtual key is not allowed to call this route. Only allowed to call routes: {valid_token.allowed_routes}. Tried to call route: {route}"
        )

    @staticmethod
    def non_proxy_admin_allowed_routes_check(
        user_obj: Optional[LiteLLM_UserTable],
        _user_role: Optional[LitellmUserRoles],
        route: str,
        request: Request,
        valid_token: UserAPIKeyAuth,
        request_data: dict,
    ):
        """
        Checks if a Non Proxy Admin User is allowed to access the route.

        Returns None when access is allowed.

        Raises:
            HTTPException: 403 for custom admin-only routes, for /user/info
                lookups of another user, and for view-only admins calling a
                management route they may not use.
            Exception: When no rule grants access to the route.
        """
        # Check user has defined custom admin routes
        RouteChecks.custom_admin_only_route_check(route=route)

        if RouteChecks.is_llm_api_route(route=route):
            return

        if route in LiteLLMRoutes.info_routes.value:
            # /key/info and /team/info are handled by the endpoints
            # themselves; /model/info only shows models the user can access.
            if route == "/user/info":
                # check if user can access this route
                requested_user_id = request.query_params.get("user_id")
                verbose_proxy_logger.debug(
                    f"user_id: {requested_user_id} & valid_token.user_id: {valid_token.user_id}"
                )
                if requested_user_id and requested_user_id != valid_token.user_id:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="key not allowed to access this user's info. user_id={}, key's user_id={}".format(
                            requested_user_id, valid_token.user_id
                        ),
                    )
            return

        if (
            route in LiteLLMRoutes.global_spend_tracking_routes.value
            and getattr(valid_token, "permissions", None) is not None
            and "get_spend_routes" in getattr(valid_token, "permissions", [])
        ):
            return

        if _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
            # NOTE(review): LLM API routes already returned above, so this
            # check cannot fire here - confirm whether view-only admins are
            # meant to be blocked from LLM routes.
            if RouteChecks.is_llm_api_route(route=route):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"user not allowed to access this OpenAI routes, role= {_user_role}",
                )
            if RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.management_routes.value
            ):
                # the only management route open to the Admin Viewer is
                # /user/update, restricted to a fixed set of params
                if route != "/user/update":
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                    )
                # Check the Request params are valid for PROXY_ADMIN_VIEW_ONLY
                if request_data is not None and isinstance(request_data, dict):
                    for param in request_data:
                        if param not in ("user_email", "password"):
                            raise HTTPException(
                                status_code=status.HTTP_403_FORBIDDEN,
                                detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route} and updating invalid param: {param}. only user_email and password can be updated",
                            )
            return

        if (
            _user_role == LitellmUserRoles.INTERNAL_USER.value
            and RouteChecks.check_route_access(
                route=route, allowed_routes=LiteLLMRoutes.internal_user_routes.value
            )
        ):
            return

        if _user_is_org_admin(
            request_data=request_data, user_object=user_obj
        ) and RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.org_admin_allowed_routes.value
        ):
            return

        if (
            _user_role == LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value
            and RouteChecks.check_route_access(
                route=route,
                allowed_routes=LiteLLMRoutes.internal_user_view_only_routes.value,
            )
        ):
            return

        # routes that manage their own allowed/disallowed logic
        if RouteChecks.check_route_access(
            route=route, allowed_routes=LiteLLMRoutes.self_managed_routes.value
        ):
            return

        user_role = "unknown"
        user_id = "unknown"
        if user_obj is not None:
            user_role = user_obj.user_role or "unknown"
            user_id = user_obj.user_id or "unknown"
        raise Exception(
            f"Only proxy admin can be used to generate, delete, update info for new keys/users/teams. Route={route}. Your role={user_role}. Your user_id={user_id}"
        )

    @staticmethod
    def custom_admin_only_route_check(route: str):
        """
        Enforce the user-configured ``admin_only_routes`` setting.

        Does nothing when the setting is absent. When it is present but the
        proxy is not premium, an error is logged and the route is let through.

        Raises:
            HTTPException: 403 if the route is listed as admin only.
        """
        from litellm.proxy.proxy_server import general_settings, premium_user

        if "admin_only_routes" not in general_settings:
            return

        if premium_user is not True:
            verbose_proxy_logger.error(
                f"Trying to use 'admin_only_routes' this is an Enterprise only feature. {CommonProxyErrors.not_premium_user.value}"
            )
            return

        if route in general_settings["admin_only_routes"]:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"user not allowed to access this route. Route={route} is an admin only route",
            )

    @staticmethod
    def is_llm_api_route(route: str) -> bool:
        """
        Helper that checks if the provided route is an LLM API route.

        A route qualifies when it is a listed OpenAI or Anthropic route,
        matches an OpenAI route containing ``{placeholders}``, looks like an
        Azure OpenAI SDK route, or contains a mapped pass-through route.

        Returns:
            - True: if route is an LLM API route
            - False: otherwise
        """
        if route in LiteLLMRoutes.openai_routes.value:
            return True

        if route in LiteLLMRoutes.anthropic_routes.value:
            return True

        # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        # placeholders are written as "/threads/{thread_id}"
        for openai_route in LiteLLMRoutes.openai_routes.value:
            if "{" in openai_route and RouteChecks._route_matches_pattern(
                route=route, pattern=openai_route
            ):
                return True

        if RouteChecks._is_azure_openai_route(route=route):
            return True

        return any(
            passthrough_route in route
            for passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value
        )

    @staticmethod
    def _is_azure_openai_route(route: str) -> bool:
        """
        Check if route is a route from the AzureOpenAI SDK client.

        eg.
        route='/openai/deployments/vertex_ai/gemini-1.5-flash/chat/completions'
        """
        # Supported deployment and engine model paths
        azure_patterns = (
            r"^/openai/deployments/[^/]+/[^/]+/chat/completions$",
            r"^/engines/[^/]+/chat/completions$",
        )
        return any(re.match(azure_pattern, route) for azure_pattern in azure_patterns)

    @staticmethod
    def _route_matches_pattern(route: str, pattern: str) -> bool:
        """
        Check if route matches the pattern placed in proxy/_types.py

        Each ``{placeholder}`` in the pattern stands for exactly one path
        segment, and the whole route has to match.

        Example:
        - pattern: "/threads/{thread_id}"
        - route: "/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
        - returns: True

        - pattern: "/key/{token_id}/regenerate"
        - route: "/key/regenerate/82akk800000000jjsk"
        - returns: False, pattern is "/key/{token_id}/regenerate"
        """
        segment_regex = re.sub(r"\{[^}]+\}", r"[^/]+", pattern)
        # Anchor the pattern to match the entire string
        return re.match(f"^{segment_regex}$", route) is not None

    @staticmethod
    def _route_matches_wildcard_pattern(route: str, pattern: str) -> bool:
        """
        Check if route matches the wildcard pattern

        eg.

        pattern: "/scim/v2/*"
        route: "/scim/v2/Users"
        - returns: True

        pattern: "/scim/v2/*"
        route: "/chat/completions"
        - returns: False

        pattern: "/scim/v2/*"
        route: "/scim/v2/Users/123"
        - returns: True
        """
        if not pattern.endswith("*"):
            # If there's no wildcard, the pattern and route should match exactly
            return route == pattern

        # Everything before the trailing wildcard must prefix the route
        return route.startswith(pattern[:-1])

    @staticmethod
    def check_route_access(route: str, allowed_routes: List[str]) -> bool:
        """
        Check if a route has access by checking both exact matches and patterns

        Args:
            route (str): The route to check
            allowed_routes (list): List of allowed routes/patterns

        Returns:
            bool: True if route is allowed, False otherwise
        """
        # Check exact match
        if route in allowed_routes:
            return True

        # Check pattern match
        return any(
            RouteChecks._route_matches_pattern(route=route, pattern=allowed_route)
            for allowed_route in allowed_routes
        )

    @staticmethod
    def _is_assistants_api_request(request: Request) -> bool:
        """
        Returns True if `thread` or `assistant` is in the request path

        Args:
            request (Request): The request object

        Returns:
            bool: True if `thread` or `assistant` is in the request path, False otherwise
        """
        path = request.url.path
        return "thread" in path or "assistant" in path
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user