structure saas with tools
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
## Supported Secret Managers to read credentials from
|
||||
|
||||
Example read OPENAI_API_KEY, AZURE_API_KEY from a secret manager
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
This is a file for the AWS KMS (Key Management Service) Integration
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||
|
||||
Requires:
|
||||
* `os.environ["AWS_REGION_NAME"]`,
|
||||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import ast
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
def validate_environment():
    """Check that the AWS region is configured.

    Raises:
        ValueError: If ``AWS_REGION_NAME`` is absent from the process environment.
    """
    region_is_configured = "AWS_REGION_NAME" in os.environ
    if region_is_configured:
        return
    raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||
|
||||
|
||||
def load_aws_kms(use_aws_kms: Optional[bool]):
    """Install a boto3 KMS client as litellm's secret manager client.

    Args:
        use_aws_kms: Feature switch. ``None`` or ``False`` makes this a no-op.

    Raises:
        ValueError: If ``AWS_REGION_NAME`` is not set.
        Exception: Anything raised while importing boto3 or building the
            client propagates unchanged.
    """
    if use_aws_kms is None or use_aws_kms is False:
        return

    # Imported lazily so boto3 is only required when AWS KMS is enabled.
    import boto3

    validate_environment()

    # Create a KMS client for the configured region
    region = os.getenv("AWS_REGION_NAME")
    litellm.secret_manager_client = boto3.client("kms", region_name=region)
    litellm._key_management_system = KeyManagementSystem.AWS_KMS
|
||||
|
||||
|
||||
class AWSKeyManagementService_V2:
    """
    V2 Clean Class for decrypting keys from AWS KeyManagementService

    Construction validates the environment and builds a boto3 KMS client;
    ``decrypt_value`` then decrypts base64 ciphertext held in an
    environment variable.
    """

    def __init__(self) -> None:
        self.validate_environment()
        self.kms_client = self.load_aws_kms(use_aws_kms=True)

    def validate_environment(
        self,
    ):
        """Check the region and the enterprise license environment variables.

        Raises:
            ValueError: If ``AWS_REGION_NAME`` is missing, or if neither
                ``LITELLM_LICENSE`` nor
                ``LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE`` is set.
        """
        if "AWS_REGION_NAME" not in os.environ:
            raise ValueError("Missing required environment variable - AWS_REGION_NAME")

        ## CHECK IF LICENSE IN ENV ## - premium feature
        license_env_vars = (
            "LITELLM_LICENSE",
            "LITELLM_SECRET_AWS_KMS_LITELLM_LICENSE",
        )
        # Only presence is checked here; the license value itself is not verified.
        if all(os.getenv(name, None) is None for name in license_env_vars):
            raise ValueError(
                "AWSKeyManagementService V2 is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your environment."
            )

    def load_aws_kms(self, use_aws_kms: Optional[bool]):
        """Build and return a boto3 KMS client for ``AWS_REGION_NAME``.

        Args:
            use_aws_kms: ``None`` or ``False`` skips client creation.

        Returns:
            The boto3 KMS client, or ``None`` when ``use_aws_kms`` is off.
        """
        if use_aws_kms is None or use_aws_kms is False:
            return None

        # Imported lazily so boto3 is only needed when KMS is actually used.
        import boto3

        # Module-level helper: re-checks only the region variable.
        validate_environment()

        # Create a KMS client
        return boto3.client("kms", region_name=os.getenv("AWS_REGION_NAME"))

    def decrypt_value(self, secret_name: str) -> Any:
        """Decrypt the KMS ciphertext stored in environment variable ``secret_name``.

        The variable's value is base64 ciphertext, optionally prefixed with
        ``aws_kms/``.

        Args:
            secret_name: Name of the environment variable holding the ciphertext.

        Returns:
            The decrypted text with surrounding whitespace stripped, or a
            ``bool`` when the text is a Python boolean literal.

        Raises:
            ValueError: If no KMS client is available.
            Exception: If the environment variable is not set.
        """
        if self.kms_client is None:
            raise ValueError("kms_client is None")

        encrypted_value = os.getenv(secret_name, None)
        if encrypted_value is None:
            raise Exception(
                "AWS KMS - Encrypted Value of Key={} is None".format(secret_name)
            )
        if encrypted_value.startswith("aws_kms/"):
            encrypted_value = encrypted_value.replace("aws_kms/", "")

        # Decode the base64 encoded ciphertext
        ciphertext_blob = base64.b64decode(encrypted_value)

        # Perform the decryption
        response = self.kms_client.decrypt(CiphertextBlob=ciphertext_blob)

        # Extract and decode the plaintext
        secret = response["Plaintext"].decode("utf-8").strip()

        # Best effort: surface "True"/"False" as real booleans. Any parse
        # failure simply means the secret is an ordinary string.
        try:
            parsed = ast.literal_eval(secret)
        except Exception:
            return secret
        if isinstance(parsed, bool):
            return parsed
        return secret
|
||||
|
||||
|
||||
"""
|
||||
- look for all values in the env with `aws_kms/<hashed_key>`
|
||||
- decrypt keys
|
||||
- rewrite env var with decrypted key. Note: this environment variable will only be available to the current process and any child processes spawned from it. Once the Python script ends, the environment variable will not persist.
|
||||
"""
|
||||
|
||||
|
||||
def decrypt_env_var() -> Dict[str, Any]:
    """Decrypt every KMS-marked environment variable and return the results.

    An entry is selected when its name starts with ``litellm_secret_aws_kms``
    (case-insensitive) or its value starts with ``aws_kms/``.

    Returns:
        Mapping of variable name, with any ``litellm_secret_aws_kms_``
        fragment removed, to the decrypted value.
    """
    # setup client class
    kms_service = AWSKeyManagementService_V2()

    def _is_kms_entry(name, value) -> bool:
        name_marked = (
            name is not None
            and isinstance(name, str)
            and name.lower().startswith("litellm_secret_aws_kms")
        )
        value_marked = (
            value is not None and isinstance(value, str) and value.startswith("aws_kms/")
        )
        return name_marked or value_marked

    decrypted: Dict[str, Any] = {}
    for env_name, env_value in os.environ.items():
        if not _is_kms_entry(env_name, env_value):
            continue
        plain_value = kms_service.decrypt_value(secret_name=env_name)
        # drop the marker from the name before storing the result
        clean_name = re.sub(
            "litellm_secret_aws_kms_", "", env_name, flags=re.IGNORECASE
        )
        decrypted[clean_name] = plain_value

    return decrypted
|
||||
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
This is a file for the AWS Secret Manager Integration
|
||||
|
||||
Handles Async Operations for:
|
||||
- Read Secret
|
||||
- Write Secret
|
||||
- Delete Secret
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/1883
|
||||
|
||||
Requires:
|
||||
* `os.environ["AWS_REGION_NAME"]`,
|
||||
* `pip install boto3>=1.28.57`
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
|
||||
from .base_secret_manager import BaseSecretManager
|
||||
|
||||
|
||||
class AWSSecretsManagerV2(BaseAWSLLM, BaseSecretManager):
|
||||
def __init__(self, **kwargs):
    """Initialise both parent classes explicitly, forwarding ``kwargs`` to each.

    ``BaseSecretManager`` is initialised first, then ``BaseAWSLLM``.
    """
    for parent in (BaseSecretManager, BaseAWSLLM):
        parent.__init__(self, **kwargs)
|
||||
|
||||
@classmethod
def validate_environment(cls):
    """Check that the AWS region is configured.

    Raises:
        ValueError: If ``AWS_REGION_NAME`` is absent from the environment.
    """
    if os.environ.get("AWS_REGION_NAME") is None:
        raise ValueError("Missing required environment variable - AWS_REGION_NAME")
|
||||
|
||||
@classmethod
def load_aws_secret_manager(cls, use_aws_secret_manager: Optional[bool]):
    """
    Initialize AWSSecretsManagerV2 and sets litellm.secret_manager_client = AWSSecretsManagerV2() and litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER

    Args:
        use_aws_secret_manager: Feature switch. ``None`` or ``False`` makes
            this a no-op.

    Raises:
        ValueError: If ``AWS_REGION_NAME`` is not set.
    """
    enabled = not (use_aws_secret_manager is None or use_aws_secret_manager is False)
    if not enabled:
        return

    # Errors from validation or construction propagate unchanged.
    cls.validate_environment()
    litellm.secret_manager_client = cls()
    litellm._key_management_system = KeyManagementSystem.AWS_SECRET_MANAGER
|
||||
|
||||
async def async_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    primary_secret_name: Optional[str] = None,
) -> Optional[str]:
    """
    Async function to read a secret from AWS Secrets Manager

    Args:
        secret_name: Secret to read, or the key to look up inside the
            primary secret when ``primary_secret_name`` is given.
        optional_params: Passed through to ``_prepare_request``.
        timeout: Timeout handed to the async HTTP client.
        primary_secret_name: When set, the lookup is delegated to
            ``async_read_secret_from_primary_secret``.

    Returns:
        The ``SecretString`` of the response, or None when the request
        fails for any reason other than a timeout (the error is logged).

    Raises:
        ValueError: If the request times out.
    """
    if primary_secret_name:
        nested_value = await self.async_read_secret_from_primary_secret(
            secret_name=secret_name, primary_secret_name=primary_secret_name
        )
        return nested_value

    request_url, signed_headers, payload = self._prepare_request(
        action="GetSecretValue",
        secret_name=secret_name,
        optional_params=optional_params,
    )

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        aws_response = await http_client.post(
            url=request_url, headers=signed_headers, data=payload.decode("utf-8")
        )
        aws_response.raise_for_status()
        secret_string = aws_response.json()["SecretString"]
    except httpx.TimeoutException:
        raise ValueError("Timeout error occurred")
    except Exception as e:
        # Every non-timeout failure is logged and reported as "no value".
        verbose_logger.exception(
            "Error reading secret='%s' from AWS Secrets Manager: %s",
            secret_name,
            str(e),
        )
        return None
    return secret_string
|
||||
|
||||
def sync_read_secret(
    self,
    secret_name: str,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    primary_secret_name: Optional[str] = None,
) -> Optional[str]:
    """
    Sync function to read a secret from AWS Secrets Manager

    Done for backwards compatibility with existing codebase, since get_secret is a sync function

    Args:
        secret_name: Secret to read, or the key to look up inside the
            primary secret when ``primary_secret_name`` is given.
        optional_params: Passed through to ``_prepare_request``.
        timeout: Timeout handed to the sync HTTP client.
        primary_secret_name: When set, the lookup is delegated to
            ``sync_read_secret_from_primary_secret``.

    Returns:
        The ``SecretString`` of the response; the plain environment value
        for the AWS bootstrap variables listed below; or None when the
        request fails for any reason other than a timeout.

    Raises:
        ValueError: If the request times out.
    """
    # self._prepare_request uses these env vars, we cannot read them from AWS Secrets Manager. If we do we'd get stuck in an infinite loop
    if secret_name in [
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_REGION_NAME",
        "AWS_REGION",
        "AWS_BEDROCK_RUNTIME_ENDPOINT",
    ]:
        return os.getenv(secret_name)

    if primary_secret_name:
        return self.sync_read_secret_from_primary_secret(
            secret_name=secret_name, primary_secret_name=primary_secret_name
        )

    endpoint_url, headers, body = self._prepare_request(
        action="GetSecretValue",
        secret_name=secret_name,
        optional_params=optional_params,
    )

    sync_client = _get_httpx_client(
        params={"timeout": timeout},
    )

    try:
        response = sync_client.post(
            url=endpoint_url, headers=headers, data=body.decode("utf-8")
        )
        # Same as async_read_secret: turn a non-2xx answer into
        # HTTPStatusError so the handler below logs the AWS error body and
        # status instead of a KeyError on the missing "SecretString".
        response.raise_for_status()
        return response.json()["SecretString"]
    except httpx.TimeoutException:
        raise ValueError("Timeout error occurred")
    except httpx.HTTPStatusError as e:
        verbose_logger.exception(
            "Error reading secret='%s' from AWS Secrets Manager: %s, %s",
            secret_name,
            str(e.response.text),
            str(e.response.status_code),
        )
    except Exception as e:
        verbose_logger.exception(
            "Error reading secret='%s' from AWS Secrets Manager: %s",
            secret_name,
            str(e),
        )
    # Reached only after a logged, non-timeout failure.
    return None
|
||||
|
||||
def _parse_primary_secret(self, primary_secret_json_str: Optional[str]) -> dict:
|
||||
"""
|
||||
Parse the primary secret JSON string into a dictionary
|
||||
|
||||
Args:
|
||||
primary_secret_json_str: JSON string containing key-value pairs
|
||||
|
||||
Returns:
|
||||
Dictionary of key-value pairs from the primary secret
|
||||
"""
|
||||
return json.loads(primary_secret_json_str or "{}")
|
||||
|
||||
def sync_read_secret_from_primary_secret(
    self, secret_name: str, primary_secret_name: str
) -> Optional[str]:
    """
    Look up ``secret_name`` inside the JSON stored under ``primary_secret_name``.

    Returns:
        The matching value, or None when the key is not in the primary secret.
    """
    raw_primary = self.sync_read_secret(secret_name=primary_secret_name)
    return self._parse_primary_secret(raw_primary).get(secret_name)
|
||||
|
||||
async def async_read_secret_from_primary_secret(
    self, secret_name: str, primary_secret_name: str
) -> Optional[str]:
    """
    Look up ``secret_name`` inside the JSON stored under ``primary_secret_name``.

    Returns:
        The matching value, or None when the key is not in the primary secret.
    """
    raw_primary = await self.async_read_secret(secret_name=primary_secret_name)
    return self._parse_primary_secret(raw_primary).get(secret_name)
|
||||
|
||||
async def async_write_secret(
    self,
    secret_name: str,
    secret_value: str,
    description: Optional[str] = None,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Async function to write a secret to AWS Secrets Manager

    Sends a ``CreateSecret`` request with a freshly generated
    ``ClientRequestToken``.

    Args:
        secret_name: Name of the secret
        secret_value: Value to store (can be a JSON string)
        description: Optional description for the secret
        optional_params: Additional AWS parameters
        timeout: Request timeout

    Returns:
        The decoded JSON response.

    Raises:
        ValueError: On an HTTP error status or a timeout.
    """
    import uuid

    # Key order is kept: Name, SecretString, [Description], ClientRequestToken.
    create_payload = {"Name": secret_name, "SecretString": secret_value}
    if description:
        create_payload["Description"] = description
    create_payload["ClientRequestToken"] = str(uuid.uuid4())

    request_url, signed_headers, payload = self._prepare_request(
        action="CreateSecret",
        secret_name=secret_name,
        secret_value=secret_value,
        optional_params=optional_params,
        request_data=create_payload,  # Pass the complete request data
    )

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        aws_response = await http_client.post(
            url=request_url, headers=signed_headers, data=payload.decode("utf-8")
        )
        aws_response.raise_for_status()
        result = aws_response.json()
    except httpx.HTTPStatusError as err:
        raise ValueError(f"HTTP error occurred: {err.response.text}")
    except httpx.TimeoutException:
        raise ValueError("Timeout error occurred")
    return result
|
||||
|
||||
async def async_delete_secret(
    self,
    secret_name: str,
    recovery_window_in_days: Optional[int] = 7,
    optional_params: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
    """
    Async function to delete a secret from AWS Secrets Manager

    Sends a ``DeleteSecret`` request.

    Args:
        secret_name: Name of the secret to delete
        recovery_window_in_days: Number of days before permanent deletion (default: 7)
        optional_params: Additional AWS parameters
        timeout: Request timeout

    Returns:
        dict: Response from AWS Secrets Manager containing deletion details

    Raises:
        ValueError: On an HTTP error status or a timeout.
    """
    delete_payload = {
        "SecretId": secret_name,
        "RecoveryWindowInDays": recovery_window_in_days,
    }

    request_url, signed_headers, payload = self._prepare_request(
        action="DeleteSecret",
        secret_name=secret_name,
        optional_params=optional_params,
        request_data=delete_payload,
    )

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.SecretManager,
        params={"timeout": timeout},
    )

    try:
        aws_response = await http_client.post(
            url=request_url, headers=signed_headers, data=payload.decode("utf-8")
        )
        aws_response.raise_for_status()
        result = aws_response.json()
    except httpx.HTTPStatusError as err:
        raise ValueError(f"HTTP error occurred: {err.response.text}")
    except httpx.TimeoutException:
        raise ValueError("Timeout error occurred")
    return result
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
action: str, # "GetSecretValue" or "PutSecretValue"
|
||||
secret_name: str,
|
||||
secret_value: Optional[str] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
) -> tuple[str, Any, bytes]:
|
||||
"""Prepare the AWS Secrets Manager request"""
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
optional_params = optional_params or {}
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params
|
||||
)
|
||||
|
||||
# Get endpoint
|
||||
_, endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=None,
|
||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
endpoint_url = endpoint_url.replace("bedrock-runtime", "secretsmanager")
|
||||
|
||||
# Use provided request_data if available, otherwise build default data
|
||||
if request_data:
|
||||
data = request_data
|
||||
else:
|
||||
data = {"SecretId": secret_name}
|
||||
if secret_value and action == "PutSecretValue":
|
||||
data["SecretString"] = secret_value
|
||||
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
headers = {
|
||||
"Content-Type": "application/x-amz-json-1.1",
|
||||
"X-Amz-Target": f"secretsmanager.{action}",
|
||||
}
|
||||
|
||||
# Sign request
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=body, headers=headers
|
||||
)
|
||||
SigV4Auth(
|
||||
boto3_credentials_info.credentials,
|
||||
"secretsmanager",
|
||||
boto3_credentials_info.aws_region_name,
|
||||
).add_auth(request)
|
||||
prepped = request.prepare()
|
||||
|
||||
return endpoint_url, prepped.headers, body
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# print("loading aws secret manager v2")
|
||||
# aws_secret_manager_v2 = AWSSecretsManagerV2()
|
||||
|
||||
# print("writing secret to aws secret manager v2")
|
||||
# asyncio.run(aws_secret_manager_v2.async_write_secret(secret_name="test_secret_3", secret_value="test_value_2"))
|
||||
# print("reading secret from aws secret manager v2")
|
||||
@@ -0,0 +1,176 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm import verbose_logger
|
||||
|
||||
|
||||
class BaseSecretManager(ABC):
    """
    Abstract base class for secret management implementations.

    Subclasses provide the read / write / delete primitives;
    ``async_rotate_secret`` is implemented here on top of them.
    """

    @abstractmethod
    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Asynchronously read a secret from the secret manager.

        Args:
            secret_name (str): Name/path of the secret to read
            optional_params (Optional[dict]): Additional parameters specific to the secret manager
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout

        Returns:
            Optional[str]: The secret value if found, None otherwise
        """
        ...

    @abstractmethod
    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Synchronously read a secret from the secret manager.

        Args:
            secret_name (str): Name/path of the secret to read
            optional_params (Optional[dict]): Additional parameters specific to the secret manager
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout

        Returns:
            Optional[str]: The secret value if found, None otherwise
        """
        ...

    @abstractmethod
    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Dict[str, Any]:
        """
        Asynchronously write a secret to the secret manager.

        Args:
            secret_name (str): Name/path of the secret to write
            secret_value (str): Value to store
            description (Optional[str]): Description of the secret. Some secret managers allow storing a description with the secret.
            optional_params (Optional[dict]): Additional parameters specific to the secret manager
            timeout (Optional[Union[float, httpx.Timeout]]): Request timeout

        Returns:
            Dict[str, Any]: Response from the secret manager containing write operation details
        """
        ...

    @abstractmethod
    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to delete a secret from the secret manager

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Number of days before permanent deletion (default: 7)
            optional_params: Additional parameters specific to the secret manager
            timeout: Request timeout

        Returns:
            dict: Response from the secret manager containing deletion details
        """
        ...

    async def async_rotate_secret(
        self,
        current_secret_name: str,
        new_secret_name: str,
        new_secret_value: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Rotate a secret: read the current one, write the new one, read the
        new one back, then delete the current one.

        This allows for both value and name changes during rotation.

        Args:
            current_secret_name: Current name of the secret
            new_secret_name: New name for the secret
            new_secret_value: New value for the secret
            optional_params: Forwarded to every read/write/delete call
            timeout: Forwarded to every read/write/delete call

        Returns:
            dict: The response of the write call for the new secret

        Raises:
            ValueError: If the current secret cannot be read, the new secret
                cannot be read back, an HTTP status error occurs, or a
                request times out
        """
        # Options shared by every primitive call below.
        request_options = {"optional_params": optional_params, "timeout": timeout}

        try:
            # First verify the old secret exists
            existing_value = await self.async_read_secret(
                secret_name=current_secret_name, **request_options
            )
            if existing_value is None:
                raise ValueError(f"Current secret {current_secret_name} not found")

            # Create new secret with new name and value
            creation_result = await self.async_write_secret(
                secret_name=new_secret_name,
                secret_value=new_secret_value,
                description=f"Rotated from {current_secret_name}",
                **request_options,
            )

            # Verify new secret was created successfully
            read_back = await self.async_read_secret(
                secret_name=new_secret_name, **request_options
            )
            if read_back is None:
                raise ValueError(f"Failed to verify new secret {new_secret_name}")

            # If everything is successful, delete the old secret
            await self.async_delete_secret(
                secret_name=current_secret_name,
                recovery_window_in_days=7,  # Keep for recovery if needed
                **request_options,
            )

            return creation_result
        except httpx.HTTPStatusError as err:
            verbose_logger.exception(
                "Error rotating secret in AWS Secrets Manager: %s",
                str(err.response.text),
            )
            raise ValueError(f"HTTP error occurred: {err.response.text}")
        except httpx.TimeoutException:
            raise ValueError("Timeout error occurred")
        except Exception as e:
            # Includes the ValueErrors raised above: log, then re-raise as-is.
            verbose_logger.exception(
                "Error rotating secret in AWS Secrets Manager: %s", str(e)
            )
            raise
|
||||
@@ -0,0 +1,38 @@
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def get_azure_ad_token_provider() -> Callable[[], str]:
    """
    Build an Azure AD bearer-token provider from environment configuration.

    ``AZURE_CREDENTIAL`` names the ``azure.identity`` credential class
    (default ``ClientSecretCredential``) and ``AZURE_SCOPE`` the token scope
    (default ``https://cognitiveservices.azure.com/.default``).

    Based on: https://github.com/openai/openai-python/blob/main/examples/azure_ad.py
    See Also:
        https://learn.microsoft.com/en-us/python/api/overview/azure/identity-readme?view=azure-python#service-principal-with-secret;
        https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.clientsecretcredential?view=azure-python.

    Returns:
        Callable that returns a temporary authentication token.
    """
    import azure.identity as identity
    from azure.identity import get_bearer_token_provider

    azure_scope = os.environ.get(
        "AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
    )
    cred = os.environ.get("AZURE_CREDENTIAL", "ClientSecretCredential")

    # Resolve the class first so an unknown name fails before any other
    # environment variable is read.
    cred_cls = getattr(identity, cred)

    # ClientSecretCredential, DefaultAzureCredential, AzureCliCredential
    if cred == "ClientSecretCredential":
        credential_kwargs = {
            "client_id": os.environ["AZURE_CLIENT_ID"],
            "client_secret": os.environ["AZURE_CLIENT_SECRET"],
            "tenant_id": os.environ["AZURE_TENANT_ID"],
        }
    elif cred == "ManagedIdentityCredential":
        credential_kwargs = {"client_id": os.environ["AZURE_CLIENT_ID"]}
    else:
        credential_kwargs = {}

    return get_bearer_token_provider(cred_cls(**credential_kwargs), azure_scope)
|
||||
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
This is a file for the Google KMS integration
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/1235
|
||||
|
||||
Requires:
|
||||
* `os.environ["GOOGLE_APPLICATION_CREDENTIALS"], os.environ["GOOGLE_KMS_RESOURCE_NAME"]`
|
||||
* `pip install google-cloud-kms`
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
|
||||
def validate_environment():
    """Check that both Google KMS environment variables are present.

    Raises:
        ValueError: For the first of ``GOOGLE_APPLICATION_CREDENTIALS`` and
            ``GOOGLE_KMS_RESOURCE_NAME`` that is missing.
    """
    required = (
        (
            "GOOGLE_APPLICATION_CREDENTIALS",
            "Missing required environment variable - GOOGLE_APPLICATION_CREDENTIALS",
        ),
        (
            "GOOGLE_KMS_RESOURCE_NAME",
            "Missing required environment variable - GOOGLE_KMS_RESOURCE_NAME",
        ),
    )
    for env_name, error_message in required:
        if env_name not in os.environ:
            raise ValueError(error_message)
|
||||
|
||||
|
||||
def load_google_kms(use_google_kms: Optional[bool]):
    """Install a Google Cloud KMS client as litellm's secret manager client.

    Args:
        use_google_kms: Feature switch. ``None`` or ``False`` makes this a no-op.

    Raises:
        ValueError: If a required environment variable is missing.
        Exception: Anything raised while importing ``google.cloud.kms_v1`` or
            building the client propagates unchanged.
    """
    if use_google_kms is None or use_google_kms is False:
        return

    # Imported lazily so google-cloud-kms is only needed when enabled.
    from google.cloud import kms_v1  # type: ignore

    validate_environment()

    # Create the KMS client and register it with litellm
    litellm.secret_manager_client = kms_v1.KeyManagementServiceClient()
    litellm._key_management_system = KeyManagementSystem.GOOGLE_KMS
    litellm._google_kms_resource_name = os.getenv("GOOGLE_KMS_RESOURCE_NAME")
|
||||
@@ -0,0 +1,117 @@
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import InMemoryCache
|
||||
from litellm.constants import SECRET_MANAGER_REFRESH_INTERVAL
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.llms.custom_httpx.http_handler import _get_httpx_client
|
||||
from litellm.proxy._types import CommonProxyErrors, KeyManagementSystem
|
||||
|
||||
|
||||
class GoogleSecretManager(GCSBucketBase):
|
||||
def __init__(
    self,
    refresh_interval: Optional[int] = SECRET_MANAGER_REFRESH_INTERVAL,
    always_read_secret_manager: Optional[bool] = False,
) -> None:
    """
    Set up the Google Secret Manager client and register it with litellm.

    Args:
        refresh_interval (int, optional): Cache TTL in seconds. Defaults to
            SECRET_MANAGER_REFRESH_INTERVAL; overridden by the
            GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL environment variable when set.
        always_read_secret_manager (bool, optional): Whether to always read from the secret manager. Defaults to False. Since we do want to cache values

    Raises:
        ValueError: If the proxy is not running as a premium user, or
            GOOGLE_SECRET_MANAGER_PROJECT_ID is not set.
    """
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            f"Google Secret Manager requires an Enterprise License {CommonProxyErrors.not_premium_user.value}"
        )
    super().__init__()

    project_id = os.environ.get("GOOGLE_SECRET_MANAGER_PROJECT_ID", None)
    self.PROJECT_ID = project_id
    if project_id is None:
        raise ValueError(
            "Google Secret Manager requires a project ID, please set 'GOOGLE_SECRET_MANAGER_PROJECT_ID' in your .env"
        )

    self.sync_httpx_client = _get_httpx_client()
    litellm.secret_manager_client = self
    litellm._key_management_system = KeyManagementSystem.GOOGLE_SECRET_MANAGER

    # An environment override wins over the constructor argument.
    configured_interval = os.environ.get(
        "GOOGLE_SECRET_MANAGER_REFRESH_INTERVAL", refresh_interval
    )
    if configured_interval:
        cache_ttl = int(configured_interval)
    else:
        cache_ttl = refresh_interval
    self.cache = InMemoryCache(default_ttl=cache_ttl)

    env_always_read = os.environ.get(
        "GOOGLE_SECRET_MANAGER_ALWAYS_READ_SECRET_MANAGER",
    )
    forced_by_env = bool(env_always_read) and env_always_read.lower() == "true"
    # by default this should be False, we want to use in memory caching for this. It's a bad idea to fetch from secret manager for all requests
    self.always_read_secret_manager = (
        True if forced_by_env else (always_read_secret_manager or False)
    )
|
||||
|
||||
def get_secret_from_google_secret_manager(self, secret_name: str) -> Optional[str]:
    """
    Retrieve a secret from the in-memory cache or from Google Secret Manager.

    Args:
        secret_name (str): The name of the secret.

    Returns:
        str: The secret value. A cached "not found" entry yields None.

    Raises:
        ValueError: If the HTTP status is not 200 or the response carries
            no payload data. The miss is cached before raising.
    """
    cache_allowed = self.always_read_secret_manager is not True
    if cache_allowed:
        cached_secret = self.cache.get_cache(secret_name)
        # A key present with value None is a remembered miss: return it
        # without calling the API again.
        if cached_secret is not None or secret_name in self.cache.cache_dict:
            return cached_secret

    _secret_name = (
        f"projects/{self.PROJECT_ID}/secrets/{secret_name}/versions/latest"
    )
    headers = self.sync_construct_request_headers()
    url = f"https://secretmanager.googleapis.com/v1/{_secret_name}:access"

    # Send the GET request to retrieve the secret
    response = self.sync_httpx_client.get(url=url, headers=headers)

    if response.status_code != 200:
        verbose_logger.error(
            "Google Secret Manager retrieval error: %s", str(response.text)
        )
        self.cache.set_cache(secret_name, None)  # Cache that the secret was not found
        raise ValueError(
            f"secret {secret_name} not found in Google Secret Manager. Error: {response.text}"
        )

    verbose_logger.debug(
        "Google Secret Manager retrieval response status code: %s",
        response.status_code,
    )

    # The payload data is base64 encoded
    encoded_payload = response.json().get("payload", {}).get("data")
    if encoded_payload is None:
        self.cache.set_cache(secret_name, None)  # Cache that the secret was not found
        raise ValueError(f"secret {secret_name} not found in Google Secret Manager")

    decoded_secret = base64.b64decode(encoded_payload).decode("utf-8")
    self.cache.set_cache(secret_name, decoded_secret)  # Cache the retrieved secret
    return decoded_secret
|
||||
@@ -0,0 +1,325 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.constants import SECRET_MANAGER_REFRESH_INTERVAL
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
from .base_secret_manager import BaseSecretManager
|
||||
|
||||
|
||||
class HashicorpSecretManager(BaseSecretManager):
    """
    Secret manager that reads, writes and deletes secrets in Hashicorp Vault
    through the KV v2 HTTP API (`/v1/[<namespace>/]secret/data/<path>`).

    Configuration is taken from the environment:

    - HCP_VAULT_ADDR: Vault base URL (default "http://127.0.0.1:8200")
    - HCP_VAULT_TOKEN: Vault token (required)
    - HCP_VAULT_NAMESPACE: optional Vault namespace
    - HCP_VAULT_CLIENT_CERT / HCP_VAULT_CLIENT_KEY: optional client cert/key;
      when both are set, requests authenticate via TLS cert login instead of
      the static token
    - HCP_VAULT_CERT_ROLE: optional cert role name sent with the TLS login
    - HCP_VAULT_REFRESH_INTERVAL: cache TTL, falls back to
      SECRET_MANAGER_REFRESH_INTERVAL

    Note: LiteLLM assumes every secret is stored under the field named "key".
    """

    def __init__(self):
        """
        Load configuration, validate it, and register this instance as the
        active litellm secret manager.

        Raises:
            ValueError: if HCP_VAULT_TOKEN is missing, if the refresh interval
                is not an integer, or if the proxy is not running as a
                premium user.
        """
        from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

        # Vault-specific config
        self.vault_addr = os.getenv("HCP_VAULT_ADDR", "http://127.0.0.1:8200")
        self.vault_token = os.getenv("HCP_VAULT_TOKEN", "")
        self.vault_namespace = os.getenv("HCP_VAULT_NAMESPACE", None)

        # Optional config for TLS cert auth
        self.tls_cert_path = os.getenv("HCP_VAULT_CLIENT_CERT", "")
        self.tls_key_path = os.getenv("HCP_VAULT_CLIENT_KEY", "")
        self.vault_cert_role = os.getenv("HCP_VAULT_CERT_ROLE", None)

        # Validate environment
        if not self.vault_token:
            raise ValueError(
                "Missing Vault token. Please set HCP_VAULT_TOKEN in your environment."
            )

        # Enforce the premium gate BEFORE mutating any global litellm state, so
        # a rejected construction cannot leave this instance registered.
        if premium_user is not True:
            raise ValueError(
                f"Hashicorp secret manager is only available for premium users. {CommonProxyErrors.not_premium_user.value}"
            )

        _refresh_interval = os.environ.get(
            "HCP_VAULT_REFRESH_INTERVAL", SECRET_MANAGER_REFRESH_INTERVAL
        )
        _refresh_interval = (
            int(_refresh_interval)
            if _refresh_interval
            else SECRET_MANAGER_REFRESH_INTERVAL
        )
        # In-memory cache for secrets (and the TLS-auth token); entries expire
        # after the refresh interval unless an explicit ttl is given.
        self.cache = InMemoryCache(default_ttl=_refresh_interval)

        # Register globally only once the instance is fully initialised.
        litellm.secret_manager_client = self
        litellm._key_management_system = KeyManagementSystem.HASHICORP_VAULT

    def _auth_via_tls_cert(self) -> str:
        """
        Obtain a Vault client token via TLS certificate login.

        Ref: https://developer.hashicorp.com/vault/api-docs/auth/cert

        Request:
        ```
        curl --request POST --cacert vault-ca.pem --cert cert.pem --key key.pem
             --header "X-Vault-Namespace: mynamespace/"
             --data '{"name": "my-cert-role"}'
             https://127.0.0.1:8200/v1/auth/cert/login
        ```

        Response:
        ```
        {
            "auth": {
                "client_token": "cf95f87d-f95b-47ff-b1f5-ba7bff850425",
                "policies": ["web", "stage"],
                "lease_duration": 3600,
                "renewable": true
            }
        }
        ```

        The token is cached under "hcp_vault_token" for `lease_duration`
        seconds and reused until that cache entry expires.

        Returns:
            The Vault client token.

        Raises:
            RuntimeError: if the login request fails or the response is malformed.
        """
        # Reuse a previously issued token instead of logging in on every request.
        cached_token = self.cache.get_cache("hcp_vault_token")
        if cached_token is not None:
            return cached_token

        verbose_logger.debug("Using TLS cert auth for Hashicorp Vault")
        # Vault endpoint for cert-based login, e.g. '/v1/auth/cert/login'
        login_url = f"{self.vault_addr}/v1/auth/cert/login"

        # Include the Vault namespace header only when a namespace is configured;
        # with the root namespace the header is omitted entirely.
        headers = {}
        if hasattr(self, "vault_namespace") and self.vault_namespace:
            headers["X-Vault-Namespace"] = self.vault_namespace
        try:
            # Client cert and key are used for mutual TLS. The context manager
            # closes the connection pool once the login call is done.
            with httpx.Client(cert=(self.tls_cert_path, self.tls_key_path)) as client:
                resp = client.post(
                    login_url,
                    headers=headers,
                    json=self._get_tls_cert_auth_body(),
                )
            resp.raise_for_status()
            auth_payload = resp.json()["auth"]
            token = auth_payload["client_token"]
            _lease_duration = auth_payload["lease_duration"]
            verbose_logger.info("Successfully obtained Vault token via TLS cert auth.")
            self.cache.set_cache(
                key="hcp_vault_token", value=token, ttl=_lease_duration
            )
            return token
        except Exception as e:
            raise RuntimeError(
                f"Could not authenticate to Vault via TLS cert: {e}"
            ) from e

    def _get_tls_cert_auth_body(self) -> dict:
        """Return the JSON body for the TLS cert login request."""
        return {"name": self.vault_cert_role}

    def get_url(self, secret_name: str) -> str:
        """
        Build the KV v2 data URL for `secret_name`.

        Result: `<vault_addr>/v1/[<namespace>/]secret/data/<secret_name>`.
        Trailing slashes on the address and surrounding slashes on the
        namespace are dropped so values such as 'mynamespace/' do not produce
        an empty path segment.
        """
        _url = f"{self.vault_addr.rstrip('/')}/v1/"
        if self.vault_namespace:
            _namespace = self.vault_namespace.strip("/")
            if _namespace:
                _url += f"{_namespace}/"
        _url += f"secret/data/{secret_name}"
        return _url

    def _get_request_headers(self) -> dict:
        """
        Return the auth header for Vault requests.

        Uses a token from TLS cert login when both the client cert and key
        paths are configured, otherwise the static HCP_VAULT_TOKEN.
        """
        if self.tls_cert_path and self.tls_key_path:
            return {"X-Vault-Token": self._auth_via_tls_cert()}
        return {"X-Vault-Token": self.vault_token}

    async def async_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Read a secret from Vault KV v2 using an async HTTPX client.

        Args:
            secret_name: Path inside the KV mount (e.g., 'myapp/config').
            optional_params: Accepted for interface compatibility; not used here.
            timeout: Accepted for interface compatibility; not used here.

        Returns:
            The value stored under the "key" field of the secret (served from
            the in-memory cache when present), or None on any failure.
        """
        cached_value = self.cache.get_cache(secret_name)
        if cached_value is not None:
            return cached_value
        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
        )
        try:
            # For KV v2: /v1/<mount>/data/<path>
            # Example: http://127.0.0.1:8200/v1/secret/data/myapp/config
            url = self.get_url(secret_name)

            response = await async_client.get(url, headers=self._get_request_headers())
            response.raise_for_status()

            # For KV v2, the secret is in response.json()["data"]["data"]
            json_resp = response.json()
            _value = self._get_secret_value_from_json_response(json_resp)
            self.cache.set_cache(secret_name, _value)
            return _value

        except Exception as e:
            # Best-effort read: failures are logged and reported as "not found".
            verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}")
            return None

    def sync_read_secret(
        self,
        secret_name: str,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Optional[str]:
        """
        Read a secret from Vault KV v2 using a sync HTTPX client.

        Args:
            secret_name: Path inside the KV mount (e.g., 'myapp/config').
            optional_params: Accepted for interface compatibility; not used here.
            timeout: Accepted for interface compatibility; not used here.

        Returns:
            The value stored under the "key" field of the secret (served from
            the in-memory cache when present), or None on any failure.
        """
        cached_value = self.cache.get_cache(secret_name)
        if cached_value is not None:
            return cached_value
        sync_client = _get_httpx_client()
        try:
            # For KV v2: /v1/<mount>/data/<path>
            url = self.get_url(secret_name)

            response = sync_client.get(url, headers=self._get_request_headers())
            response.raise_for_status()

            # For KV v2, the secret is in response.json()["data"]["data"]
            json_resp = response.json()
            _value = self._get_secret_value_from_json_response(json_resp)
            self.cache.set_cache(secret_name, _value)
            return _value

        except Exception as e:
            # Best-effort read: failures are logged and reported as "not found".
            verbose_logger.exception(f"Error reading secret from Hashicorp Vault: {e}")
            return None

    async def async_write_secret(
        self,
        secret_name: str,
        secret_value: str,
        description: Optional[str] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Dict[str, Any]:
        """
        Write a secret to Vault KV v2 using an async HTTPX client.

        The value is stored under the "key" field; `description`, when given,
        is stored next to it under "description".

        Args:
            secret_name: Path inside the KV mount (e.g., 'myapp/config')
            secret_value: Value to store
            description: Optional description for the secret
            optional_params: Accepted for interface compatibility; not used here.
            timeout: Request timeout

        Returns:
            dict: Vault's JSON response on success, or
            {"status": "error", "message": ...} on failure.
        """
        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            url = self.get_url(secret_name)

            # Prepare the secret data
            data = {"data": {"key": secret_value}}

            if description:
                data["data"]["description"] = description

            response = await async_client.post(
                url=url, headers=self._get_request_headers(), json=data
            )
            response.raise_for_status()
            # Keep the cache in step with Vault so reads issued before the TTL
            # expires do not return the previous value.
            self.cache.set_cache(secret_name, secret_value)
            return response.json()
        except Exception as e:
            verbose_logger.exception(f"Error writing secret to Hashicorp Vault: {e}")
            return {"status": "error", "message": str(e)}

    async def async_rotate_secret(
        self,
        current_secret_name: str,
        new_secret_name: str,
        new_secret_value: str,
        optional_params: Optional[Dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> Dict:
        """
        Secret rotation is not supported by this manager.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Hashicorp does not support secret rotation")

    async def async_delete_secret(
        self,
        secret_name: str,
        recovery_window_in_days: Optional[int] = 7,
        optional_params: Optional[dict] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> dict:
        """
        Async function to delete a secret from Hashicorp Vault.
        In KV v2, this marks the latest version of the secret as deleted.

        Args:
            secret_name: Name of the secret to delete
            recovery_window_in_days: Not used for Vault (Vault handles this internally)
            optional_params: Accepted for interface compatibility; not used here.
            timeout: Request timeout

        Returns:
            dict: {"status": "success", ...} on success, or
            {"status": "error", "message": ...} on failure.
        """
        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.SecretManager,
            params={"timeout": timeout},
        )

        try:
            # For KV v2 delete: /v1/<mount>/data/<path>
            url = self.get_url(secret_name)

            response = await async_client.delete(
                url=url, headers=self._get_request_headers()
            )
            response.raise_for_status()

            # Clear the cache for this secret
            self.cache.delete_cache(secret_name)

            return {
                "status": "success",
                "message": f"Secret {secret_name} deleted successfully",
            }
        except Exception as e:
            verbose_logger.exception(f"Error deleting secret from Hashicorp Vault: {e}")
            return {"status": "error", "message": str(e)}

    def _get_secret_value_from_json_response(
        self, json_resp: Optional[dict]
    ) -> Optional[str]:
        """
        Get the secret value from the JSON response

        Json response from hashicorp vault is of the form:

        {
            "request_id":"036ba77c-018b-31dd-047b-323bcd0cd332",
            "lease_id":"",
            "renewable":false,
            "lease_duration":0,
            "data":
                {"data":
                    {"key":"Vault Is The Way"},
                    "metadata":{"created_time":"2025-01-01T22:13:50.93942388Z","custom_metadata":null,"deletion_time":"","destroyed":false,"version":1}
                },
            "wrap_info":null,
            "warnings":null,
            "auth":null,
            "mount_type":"kv"
        }

        Note: LiteLLM assumes that all secrets are stored under the key "key"

        Returns:
            The value under data.data.key, or None when the response is None
            or any level of that path is missing or null.
        """
        if json_resp is None:
            return None
        # `or {}` tolerates an explicit JSON null at either "data" level, which
        # a plain .get(..., {}) default would not.
        outer = json_resp.get("data") or {}
        inner = outer.get("data") or {}
        return inner.get("key", None)
|
||||
@@ -0,0 +1,354 @@
|
||||
import ast
|
||||
import base64
|
||||
import binascii
|
||||
import os
|
||||
import traceback
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.proxy._types import KeyManagementSystem
|
||||
|
||||
oidc_cache = DualCache()
|
||||
|
||||
|
||||
######### Secret Manager ############################
|
||||
# checks if user has passed in a secret manager client
|
||||
# if passed in then checks the secret there
|
||||
def _is_base64(s):
|
||||
try:
|
||||
return base64.b64encode(base64.b64decode(s)).decode() == s
|
||||
except binascii.Error:
|
||||
return False
|
||||
|
||||
|
||||
def str_to_bool(value: Optional[str]) -> Optional[bool]:
    """
    Map the strings "true" / "false" to their boolean value.

    Matching ignores case and surrounding whitespace.

    :param value: The string to interpret; may be None.
    :return: True or False for a recognized boolean string, otherwise None
        (including when `value` itself is None).
    """
    if value is None:
        return None

    recognized = {"true": True, "false": False}
    normalized = value.strip().lower()
    return recognized.get(normalized)
|
||||
|
||||
|
||||
def get_secret_str(
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
) -> Optional[str]:
    """
    String-only wrapper around `get_secret`.

    Returns whatever `get_secret` resolves when that is a string; any other
    result (None, or a non-string such as a bool) becomes None.
    """
    resolved = get_secret(secret_name=secret_name, default_value=default_value)
    return resolved if isinstance(resolved, str) else None
|
||||
|
||||
|
||||
def get_secret_bool(
    secret_name: str,
    default_value: Optional[bool] = None,
) -> Optional[bool]:
    """
    Boolean-only wrapper around `get_secret`.

    Args:
        secret_name: The name of the secret to get.
        default_value: Passed through to `get_secret` as its default.

    Returns:
        The resolved value when it is already a bool; None when nothing was
        resolved; otherwise the result of `str_to_bool` on the resolved value
        (which is None for anything other than "true"/"false").
    """
    resolved = get_secret(secret_name, default_value)
    if isinstance(resolved, bool):
        return resolved
    if resolved is None:
        return None
    return str_to_bool(resolved)
|
||||
|
||||
|
||||
def _get_oidc_token(secret_name: str) -> Optional[str]:
    """
    Resolve an `oidc/<provider>/<audience>` reference to an OIDC token.

    Example: oidc/google/https://bedrock-runtime.us-east-1.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke

    Supported providers: google, circleci, circleci_v2, github, azure, file,
    env, env_path. Tokens fetched over HTTP (google, github) are cached in
    `oidc_cache` under the full `secret_name`.

    Raises:
        ValueError: for an unsupported provider, missing environment
            configuration, a failed provider request, or a reference that
            has no `/<audience>` part.
    """
    # Strip only the leading prefix; the audience is often a URL or file path
    # and may itself contain "oidc/".
    provider_and_audience = secret_name[len("oidc/") :]
    oidc_provider, oidc_aud = provider_and_audience.split("/", 1)

    if oidc_provider == "google":
        oidc_token = oidc_cache.get_cache(key=secret_name)
        if oidc_token is not None:
            return oidc_token

        oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
        # https://cloud.google.com/compute/docs/instances/verifying-instance-identity#request_signature
        response = oidc_client.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity",
            params={"audience": oidc_aud},
            headers={"Metadata-Flavor": "Google"},
        )
        if response.status_code == 200:
            oidc_token = response.text
            oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=3600 - 60)
            return oidc_token
        raise ValueError("Google OIDC provider failed")

    if oidc_provider == "circleci":
        # https://circleci.com/docs/openid-connect-tokens/
        env_secret = os.getenv("CIRCLE_OIDC_TOKEN")
        if env_secret is None:
            raise ValueError("CIRCLE_OIDC_TOKEN not found in environment")
        return env_secret

    if oidc_provider == "circleci_v2":
        # https://circleci.com/docs/openid-connect-tokens/
        env_secret = os.getenv("CIRCLE_OIDC_TOKEN_V2")
        if env_secret is None:
            raise ValueError("CIRCLE_OIDC_TOKEN_V2 not found in environment")
        return env_secret

    if oidc_provider == "github":
        # https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers#using-custom-actions
        actions_id_token_request_url = os.getenv("ACTIONS_ID_TOKEN_REQUEST_URL")
        actions_id_token_request_token = os.getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN")
        if (
            actions_id_token_request_url is None
            or actions_id_token_request_token is None
        ):
            raise ValueError(
                "ACTIONS_ID_TOKEN_REQUEST_URL or ACTIONS_ID_TOKEN_REQUEST_TOKEN not found in environment"
            )

        oidc_token = oidc_cache.get_cache(key=secret_name)
        if oidc_token is not None:
            return oidc_token

        oidc_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
        response = oidc_client.get(
            actions_id_token_request_url,
            params={"audience": oidc_aud},
            headers={
                "Authorization": f"Bearer {actions_id_token_request_token}",
                "Accept": "application/json; api-version=2.0",
            },
        )
        if response.status_code == 200:
            oidc_token = response.json().get("value", None)
            oidc_cache.set_cache(key=secret_name, value=oidc_token, ttl=300 - 5)
            return oidc_token
        raise ValueError("Github OIDC provider failed")

    if oidc_provider == "azure":
        # https://azure.github.io/azure-workload-identity/docs/quick-start.html
        azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
        if azure_federated_token_file is None:
            raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
        with open(azure_federated_token_file, "r") as f:
            return f.read()

    if oidc_provider == "file":
        # Load token from a file; the "audience" part is the file path.
        with open(oidc_aud, "r") as f:
            return f.read()

    if oidc_provider == "env":
        # Load token directly from an environment variable
        oidc_token = os.getenv(oidc_aud)
        if oidc_token is None:
            raise ValueError(f"Environment variable {oidc_aud} not found")
        return oidc_token

    if oidc_provider == "env_path":
        # Load token from a file path specified in an environment variable
        token_file_path = os.getenv(oidc_aud)
        if token_file_path is None:
            raise ValueError(f"Environment variable {oidc_aud} not found")
        with open(token_file_path, "r") as f:
            return f.read()

    raise ValueError("Unsupported OIDC provider")


def get_secret(  # noqa: PLR0915
    secret_name: str,
    default_value: Optional[Union[str, bool]] = None,
):
    """
    Resolve `secret_name` to its value.

    Resolution order:
    1. A leading "os.environ/" prefix is dropped.
    2. Names starting with "oidc/" are resolved to an OIDC token (errors from
       this step propagate; `default_value` is not applied to them).
    3. If a secret manager client is configured and readable, the secret is
       read from the configured key management system. Any error there is
       logged and the lookup falls back to `os.environ`.
    4. Otherwise the value is read from `os.environ`.

    Args:
        secret_name: Name of the secret, optionally prefixed with
            "os.environ/" or "oidc/<provider>/".
        default_value: Returned only when the lookup raises and this is not None.

    Returns:
        The secret as a string, a bool when the value spells a boolean, or
        None when nothing was found.
    """
    key_management_system = litellm._key_management_system
    key_management_settings = litellm._key_management_settings
    secret = None

    if secret_name.startswith("os.environ/"):
        # Remove the prefix only; str.replace would also strip later occurrences.
        secret_name = secret_name[len("os.environ/") :]

    if secret_name.startswith("oidc/"):
        return _get_oidc_token(secret_name)

    try:
        if (
            _should_read_secret_from_secret_manager()
            and litellm.secret_manager_client is not None
        ):
            try:
                client = litellm.secret_manager_client
                key_manager = "local"
                if key_management_system is not None:
                    key_manager = key_management_system.value

                if key_management_settings is not None:
                    if (
                        key_management_settings.hosted_keys is not None
                        and secret_name not in key_management_settings.hosted_keys
                    ):  # allow user to specify which keys to check in hosted key manager
                        key_manager = "local"

                if (
                    key_manager == KeyManagementSystem.AZURE_KEY_VAULT.value
                    or type(client).__module__ + "." + type(client).__name__
                    == "azure.keyvault.secrets._client.SecretClient"
                ):  # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
                    secret = client.get_secret(secret_name).value
                elif (
                    key_manager == KeyManagementSystem.GOOGLE_KMS.value
                    or client.__class__.__name__ == "KeyManagementServiceClient"
                ):
                    encrypted_secret: Any = os.getenv(secret_name)
                    if encrypted_secret is None:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be in the environment!"
                        )
                    if _is_base64(encrypted_secret) is not True:
                        raise ValueError(
                            "Google KMS requires the encrypted secret to be encoded in base64"
                        )  # fix for this vulnerability https://huntr.com/bounties/ae623c2f-b64b-4245-9ed4-f13a0a5824ce
                    ciphertext = base64.b64decode(encrypted_secret)
                    response = client.decrypt(
                        request={
                            "name": litellm._google_kms_resource_name,
                            "ciphertext": ciphertext,
                        }
                    )
                    secret = response.plaintext.decode(
                        "utf-8"
                    )  # assumes the original value was encoded with utf-8
                elif key_manager == KeyManagementSystem.AWS_KMS.value:
                    # The environment holds the base64-encoded ciphertext; KMS
                    # decrypts it to the actual secret.
                    encrypted_value = os.getenv(secret_name, None)
                    if encrypted_value is None:
                        raise Exception(
                            "AWS KMS - Encrypted Value of Key={} is None".format(
                                secret_name
                            )
                        )
                    # Decode the base64 encoded ciphertext
                    ciphertext_blob = base64.b64decode(encrypted_value)

                    # Perform the decryption
                    response = client.decrypt(CiphertextBlob=ciphertext_blob)

                    # Extract and decode the plaintext
                    secret = response["Plaintext"].decode("utf-8")
                    if isinstance(secret, str):
                        secret = secret.strip()
                elif key_manager == KeyManagementSystem.AWS_SECRET_MANAGER.value:
                    from litellm.secret_managers.aws_secret_manager_v2 import (
                        AWSSecretsManagerV2,
                    )

                    if isinstance(client, AWSSecretsManagerV2):
                        secret = client.sync_read_secret(
                            secret_name=secret_name,
                            primary_secret_name=key_management_settings.primary_secret_name,
                        )
                    # Never log the secret value itself - only whether it resolved.
                    print_verbose(
                        f"AWS Secrets Manager lookup for {secret_name}: found={secret is not None}"
                    )
                elif key_manager == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value:
                    try:
                        secret = client.get_secret_from_google_secret_manager(
                            secret_name
                        )
                        # Never log the secret value itself - only whether it resolved.
                        print_verbose(
                            f"Google Secret Manager lookup for {secret_name}: found={secret is not None}"
                        )
                        if secret is None:
                            raise ValueError(
                                f"No secret found in Google Secret Manager for {secret_name}"
                            )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        raise
                elif key_manager == KeyManagementSystem.HASHICORP_VAULT.value:
                    try:
                        secret = client.sync_read_secret(secret_name=secret_name)
                        if secret is None:
                            raise ValueError(
                                f"No secret found in Hashicorp Secret Manager for {secret_name}"
                            )
                    except Exception as e:
                        print_verbose(f"An error occurred - {str(e)}")
                        raise
                elif key_manager == "local":
                    secret = os.getenv(secret_name)
                else:  # assume the default is infisicial client
                    secret = client.get_secret(secret_name).secret_value
            except Exception as e:  # check if it's in os.environ
                verbose_logger.error(
                    f"Defaulting to os.environ value for key={secret_name}. An exception occurred - {str(e)}.\n\n{traceback.format_exc()}"
                )
                secret = os.getenv(secret_name)
            try:
                if isinstance(secret, str):
                    # "True"/"False" literals are returned as real booleans;
                    # every other string is returned unchanged.
                    secret_value_as_bool = ast.literal_eval(secret)
                    if isinstance(secret_value_as_bool, bool):
                        return secret_value_as_bool
                    return secret
            except Exception:
                # Not a Python literal - it is an ordinary string secret.
                return secret
        else:
            secret = os.environ.get(secret_name)
            secret_value_as_bool = str_to_bool(secret) if secret is not None else None
            if secret_value_as_bool is not None and isinstance(
                secret_value_as_bool, bool
            ):
                return secret_value_as_bool
            return secret
    except Exception:
        if default_value is not None:
            return default_value
        raise
|
||||
|
||||
|
||||
def _should_read_secret_from_secret_manager() -> bool:
    """
    Decide whether secrets should be read through the secret manager.

    - False when no secret manager client is set
    - False when `_key_management_settings` is not set
    - True when the settings' access mode is "read_only" or "read_and_write"
    - False for any other access mode
    """
    if litellm.secret_manager_client is None:
        return False

    settings = litellm._key_management_settings
    if settings is None:
        return False

    return settings.access_mode in ("read_only", "read_and_write")
|
||||
Reference in New Issue
Block a user