structure saas with tools
@@ -0,0 +1,627 @@
import hashlib
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast, get_args

import httpx
from pydantic import BaseModel

from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL, BEDROCK_MAX_POLICY_SIZE
from litellm.litellm_core_utils.dd_tracing import tracer
from litellm.secret_managers.main import get_secret

if TYPE_CHECKING:
    from botocore.awsrequest import AWSPreparedRequest
    from botocore.credentials import Credentials
else:
    Credentials = Any
    AWSPreparedRequest = Any


class Boto3CredentialsInfo(BaseModel):
|
||||
credentials: Credentials
|
||||
aws_region_name: str
|
||||
aws_bedrock_runtime_endpoint: Optional[str]
|
||||
|
||||
|
||||
class AwsAuthError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class BaseAWSLLM:
|
||||
def __init__(self) -> None:
|
||||
self.iam_cache = DualCache()
|
||||
super().__init__()
|
||||
self.aws_authentication_params = [
|
||||
"aws_access_key_id",
|
||||
"aws_secret_access_key",
|
||||
"aws_session_token",
|
||||
"aws_region_name",
|
||||
"aws_session_name",
|
||||
"aws_profile_name",
|
||||
"aws_role_name",
|
||||
"aws_web_identity_token",
|
||||
"aws_sts_endpoint",
|
||||
"aws_bedrock_runtime_endpoint",
|
||||
]
|
||||
|
||||
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
|
||||
"""
|
||||
Generate a unique cache key based on the credential arguments.
|
||||
"""
|
||||
# Convert credential arguments to a JSON string and hash it to create a unique key
|
||||
credential_str = json.dumps(credential_args, sort_keys=True)
|
||||
return hashlib.sha256(credential_str.encode()).hexdigest()
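    # Illustrative sketch (not part of the class): because the arguments are
    # serialized with sort_keys=True, two calls with the same credential kwargs
    # produce the same SHA-256 key regardless of dict ordering. Values below
    # are made up.
    #
    #   key_a = BaseAWSLLM().get_cache_key(
    #       {"aws_region_name": "us-east-1", "aws_profile_name": "dev"}
    #   )
    #   key_b = BaseAWSLLM().get_cache_key(
    #       {"aws_profile_name": "dev", "aws_region_name": "us-east-1"}
    #   )
    #   assert key_a == key_b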
|
||||
|
||||
@tracer.wrap()
|
||||
def get_credentials(
|
||||
self,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
aws_sts_endpoint: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Return a boto3.Credentials object
|
||||
"""
|
||||
        ## CHECK IF 'os.environ/' IS PASSED IN
|
||||
params_to_check: List[Optional[str]] = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_session_token,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
aws_sts_endpoint,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith("os.environ/"):
|
||||
_v = get_secret(param)
|
||||
if _v is not None and isinstance(_v, str):
|
||||
params_to_check[i] = _v
|
||||
elif param is None: # check if uppercase value in env
|
||||
key = self.aws_authentication_params[i]
|
||||
if key.upper() in os.environ:
|
||||
params_to_check[i] = os.getenv(key)
|
||||
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_session_token,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
aws_sts_endpoint,
|
||||
) = params_to_check
|
||||
|
||||
verbose_logger.debug(
|
||||
"in get credentials\n"
|
||||
"aws_access_key_id=%s\n"
|
||||
"aws_secret_access_key=%s\n"
|
||||
"aws_session_token=%s\n"
|
||||
"aws_region_name=%s\n"
|
||||
"aws_session_name=%s\n"
|
||||
"aws_profile_name=%s\n"
|
||||
"aws_role_name=%s\n"
|
||||
"aws_web_identity_token=%s\n"
|
||||
"aws_sts_endpoint=%s",
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_session_token,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
aws_sts_endpoint,
|
||||
)
|
||||
|
||||
# create cache key for non-expiring auth flows
|
||||
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
|
||||
|
||||
cache_key = self.get_cache_key(args)
|
||||
_cached_credentials = self.iam_cache.get_cache(cache_key)
|
||||
if _cached_credentials:
|
||||
return _cached_credentials
|
||||
|
||||
#########################################################
|
||||
# Handle diff boto3 auth flows
|
||||
# for each helper
|
||||
# Return:
|
||||
# Credentials - boto3.Credentials
|
||||
# cache ttl - Optional[int]. If None, the credentials are not cached. Some auth flows have no expiry time.
|
||||
#########################################################
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
credentials, _cache_ttl = self._auth_with_web_identity_token(
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
credentials, _cache_ttl = self._auth_with_aws_role(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
)
|
||||
|
||||
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||
credentials, _cache_ttl = self._auth_with_aws_profile(aws_profile_name)
|
||||
elif (
|
||||
aws_access_key_id is not None
|
||||
and aws_secret_access_key is not None
|
||||
and aws_session_token is not None
|
||||
):
|
||||
credentials, _cache_ttl = self._auth_with_aws_session_token(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
)
|
||||
elif (
|
||||
aws_access_key_id is not None
|
||||
and aws_secret_access_key is not None
|
||||
and aws_region_name is not None
|
||||
):
|
||||
credentials, _cache_ttl = self._auth_with_access_key_and_secret_key(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
else:
|
||||
credentials, _cache_ttl = self._auth_with_env_vars()
|
||||
|
||||
self.iam_cache.set_cache(cache_key, credentials, ttl=_cache_ttl)
|
||||
return credentials
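    # Illustrative usage (hypothetical values; `llm` is a BaseAWSLLM instance):
    # the helpers above are tried in priority order - web identity token + role,
    # role assumption, named profile, static key/secret(/session token), then
    # boto3's own env-var lookup - and the result is cached under the hashed
    # argument set.
    #
    #   creds = llm.get_credentials(
    #       aws_access_key_id="os.environ/MY_AWS_ACCESS_KEY_ID",      # resolved via get_secret
    #       aws_secret_access_key="os.environ/MY_AWS_SECRET_ACCESS_KEY",
    #       aws_region_name="us-east-1",
    #   )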
|
||||
|
||||
def _get_aws_region_from_model_arn(self, model: Optional[str]) -> Optional[str]:
|
||||
try:
|
||||
# First check if the string contains the expected prefix
|
||||
if not isinstance(model, str) or "arn:aws:bedrock" not in model:
|
||||
return None
|
||||
|
||||
# Split the ARN and check if we have enough parts
|
||||
parts = model.split(":")
|
||||
if len(parts) < 4:
|
||||
return None
|
||||
|
||||
# Get the region from the correct position
|
||||
region = parts[3]
|
||||
if not region: # Check if region is empty
|
||||
return None
|
||||
|
||||
return region
|
||||
except Exception:
|
||||
# Catch any unexpected errors and return None
|
||||
return None
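    # Parsing sketch, using the ARN format from the docstrings below: splitting
    # on ":" puts the region at index 3, so
    #   "arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    # yields parts[3] == "us-east-1", which is what this helper returns.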
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_from_model_path(
|
||||
model_path: str,
|
||||
) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
|
||||
"""
|
||||
Helper function to get the provider from a model path with format: provider/model-name
|
||||
|
||||
Args:
|
||||
model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
|
||||
|
||||
Returns:
|
||||
Optional[str]: The provider name, or None if no valid provider found
|
||||
"""
|
||||
parts = model_path.split("/")
|
||||
if len(parts) >= 1:
|
||||
provider = parts[0]
|
||||
if provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
|
||||
return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_invoke_provider(
|
||||
model: str,
|
||||
) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
|
||||
"""
|
||||
Helper function to get the bedrock provider from the model
|
||||
|
||||
handles 3 scenarions:
|
||||
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
|
||||
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
|
||||
"""
|
||||
if model.startswith("invoke/"):
|
||||
model = model.replace("invoke/", "", 1)
|
||||
|
||||
_split_model = model.split(".")[0]
|
||||
if _split_model in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
|
||||
return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
|
||||
|
||||
# If not a known provider, check for pattern with two slashes
|
||||
provider = BaseAWSLLM._get_provider_from_model_path(model)
|
||||
if provider is not None:
|
||||
return provider
|
||||
|
||||
# check if provider == "nova"
|
||||
if "nova" in model:
|
||||
return "nova"
|
||||
else:
|
||||
for provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
|
||||
if provider in model:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def _get_aws_region_name(
|
||||
self,
|
||||
optional_params: dict,
|
||||
model: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the AWS region name from the environment variables.
|
||||
|
||||
Parameters:
|
||||
optional_params (dict): Optional parameters for the model call
|
||||
model (str): The model name
|
||||
model_id (str): The model ID. This is the ARN of the model, if passed in as a separate param.
|
||||
|
||||
Returns:
|
||||
str: The AWS region name
|
||||
"""
|
||||
aws_region_name = optional_params.get("aws_region_name", None)
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check model arn #
|
||||
if model_id is not None:
|
||||
aws_region_name = self._get_aws_region_from_model_arn(model_id)
|
||||
else:
|
||||
aws_region_name = self._get_aws_region_from_model_arn(model)
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if (
|
||||
aws_region_name is None
|
||||
and litellm_aws_region_name is not None
|
||||
and isinstance(litellm_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if (
|
||||
aws_region_name is None
|
||||
and standard_aws_region_name is not None
|
||||
and isinstance(standard_aws_region_name, str)
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
return aws_region_name
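    # Resolution-order sketch (assumed inputs, env vars unset; `llm` is a
    # BaseAWSLLM instance): an explicit optional_params["aws_region_name"] wins,
    # then a region parsed from the model/model_id ARN, then AWS_REGION_NAME,
    # then AWS_REGION, and finally the hard-coded "us-west-2" fallback.
    #
    #   region = llm._get_aws_region_name(
    #       optional_params={},
    #       model="arn:aws:bedrock:eu-central-1:000000000000:imported-model/abc",
    #   )  # -> "eu-central-1", taken from the (made-up) ARN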
|
||||
|
||||
@tracer.wrap()
|
||||
def _auth_with_web_identity_token(
|
||||
self,
|
||||
aws_web_identity_token: str,
|
||||
aws_role_name: str,
|
||||
aws_session_name: str,
|
||||
aws_region_name: Optional[str],
|
||||
aws_sts_endpoint: Optional[str],
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Web Identity Token
|
||||
"""
|
||||
import boto3
|
||||
|
||||
verbose_logger.debug(
|
||||
f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}"
|
||||
)
|
||||
|
||||
if aws_sts_endpoint is None:
|
||||
sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com"
|
||||
else:
|
||||
sts_endpoint = aws_sts_endpoint
|
||||
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise AwsAuthError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
with tracer.trace("boto3.client(sts)"):
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
region_name=aws_region_name,
|
||||
endpoint_url=sts_endpoint,
|
||||
)
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
Policy='{"Version":"2012-10-17","Statement":[{"Sid":"BedrockLiteLLM","Effect":"Allow","Action":["bedrock:InvokeModel","bedrock:InvokeModelWithResponseStream"],"Resource":"*","Condition":{"Bool":{"aws:SecureTransport":"true"},"StringLike":{"aws:UserAgent":"litellm/*"}}}]}',
|
||||
)
|
||||
|
||||
iam_creds_dict = {
|
||||
"aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
|
||||
"aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
|
||||
"aws_session_token": sts_response["Credentials"]["SessionToken"],
|
||||
"region_name": aws_region_name,
|
||||
}
|
||||
|
||||
if sts_response["PackedPolicySize"] > BEDROCK_MAX_POLICY_SIZE:
|
||||
verbose_logger.warning(
|
||||
f"The policy size is greater than 75% of the allowed size, PackedPolicySize: {sts_response['PackedPolicySize']}"
|
||||
)
|
||||
|
||||
with tracer.trace("boto3.Session(**iam_creds_dict)"):
|
||||
session = boto3.Session(**iam_creds_dict)
|
||||
|
||||
iam_creds = session.get_credentials()
|
||||
return iam_creds, self._get_default_ttl_for_boto3_credentials()
|
||||
|
||||
@tracer.wrap()
|
||||
def _auth_with_aws_role(
|
||||
self,
|
||||
aws_access_key_id: Optional[str],
|
||||
aws_secret_access_key: Optional[str],
|
||||
aws_role_name: str,
|
||||
aws_session_name: str,
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Role
|
||||
"""
|
||||
import boto3
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
with tracer.trace("boto3.client(sts)"):
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
# Extract the credentials from the response and convert to Session Credentials
|
||||
sts_credentials = sts_response["Credentials"]
|
||||
credentials = Credentials(
|
||||
access_key=sts_credentials["AccessKeyId"],
|
||||
secret_key=sts_credentials["SecretAccessKey"],
|
||||
token=sts_credentials["SessionToken"],
|
||||
)
|
||||
|
||||
sts_expiry = sts_credentials["Expiration"]
|
||||
# Convert to timezone-aware datetime for comparison
|
||||
current_time = datetime.now(sts_expiry.tzinfo)
|
||||
sts_ttl = (sts_expiry - current_time).total_seconds() - 60
|
||||
return credentials, sts_ttl
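    # TTL sketch: if STS reports an Expiration one hour out, the credentials are
    # cached for that hour minus the 60-second safety margin applied above,
    # i.e. roughly 3600 - 60 = 3540 seconds.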
|
||||
|
||||
@tracer.wrap()
|
||||
def _auth_with_aws_profile(
|
||||
self, aws_profile_name: str
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS profile
|
||||
"""
|
||||
import boto3
|
||||
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
with tracer.trace("boto3.Session(profile_name=aws_profile_name)"):
|
||||
client = boto3.Session(profile_name=aws_profile_name)
|
||||
return client.get_credentials(), None
|
||||
|
||||
@tracer.wrap()
|
||||
def _auth_with_aws_session_token(
|
||||
self,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
aws_session_token: str,
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Session Token
|
||||
"""
|
||||
### CHECK FOR AWS SESSION TOKEN ###
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=aws_access_key_id,
|
||||
secret_key=aws_secret_access_key,
|
||||
token=aws_session_token,
|
||||
)
|
||||
|
||||
return credentials, None
|
||||
|
||||
@tracer.wrap()
|
||||
def _auth_with_access_key_and_secret_key(
|
||||
self,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
aws_region_name: Optional[str],
|
||||
) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Access Key and Secret Key
|
||||
"""
|
||||
import boto3
|
||||
|
||||
# Check if credentials are already in cache. These credentials have no expiry time.
|
||||
with tracer.trace(
|
||||
"boto3.Session(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name)"
|
||||
):
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
credentials = session.get_credentials()
|
||||
return credentials, self._get_default_ttl_for_boto3_credentials()
|
||||
|
||||
@tracer.wrap()
|
||||
def _auth_with_env_vars(self) -> Tuple[Credentials, Optional[int]]:
|
||||
"""
|
||||
Authenticate with AWS Environment Variables
|
||||
"""
|
||||
import boto3
|
||||
|
||||
with tracer.trace("boto3.Session()"):
|
||||
session = boto3.Session()
|
||||
credentials = session.get_credentials()
|
||||
return credentials, None
|
||||
|
||||
@tracer.wrap()
|
||||
def _get_default_ttl_for_boto3_credentials(self) -> int:
|
||||
"""
|
||||
Get the default TTL for boto3 credentials
|
||||
|
||||
Returns `3600-60` which is 59 minutes
|
||||
"""
|
||||
return 3600 - 60
|
||||
|
||||
def get_runtime_endpoint(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
aws_bedrock_runtime_endpoint: Optional[str],
|
||||
aws_region_name: str,
|
||||
) -> Tuple[str, str]:
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if api_base is not None:
|
||||
endpoint_url = api_base
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
# Determine proxy_endpoint_url
|
||||
if env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
proxy_endpoint_url = aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
proxy_endpoint_url = endpoint_url
|
||||
|
||||
return endpoint_url, proxy_endpoint_url
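    # Precedence sketch (assuming AWS_BEDROCK_RUNTIME_ENDPOINT is unset; `llm` is
    # a BaseAWSLLM instance):
    #
    #   endpoint, proxy = llm.get_runtime_endpoint(
    #       api_base=None,
    #       aws_bedrock_runtime_endpoint=None,
    #       aws_region_name="us-east-1",
    #   )
    #   # both -> "https://bedrock-runtime.us-east-1.amazonaws.com"
    #
    # An explicit api_base wins for the signed endpoint, while the env var takes
    # priority for the proxy endpoint.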
|
||||
|
||||
def _get_boto_credentials_from_optional_params(
|
||||
self, optional_params: dict, model: Optional[str] = None
|
||||
) -> Boto3CredentialsInfo:
|
||||
"""
|
||||
Get boto3 credentials from optional params
|
||||
|
||||
Args:
|
||||
optional_params (dict): Optional parameters for the model call
|
||||
|
||||
Returns:
|
||||
Credentials: Boto3 credentials object
|
||||
"""
|
||||
try:
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_region_name = self._get_aws_region_name(optional_params, model)
|
||||
optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
|
||||
return Boto3CredentialsInfo(
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
)
|
||||
|
||||
@tracer.wrap()
|
||||
def get_request_headers(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
aws_region_name: str,
|
||||
extra_headers: Optional[dict],
|
||||
endpoint_url: str,
|
||||
data: str,
|
||||
headers: dict,
|
||||
) -> AWSPreparedRequest:
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
return prepped
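    # Illustrative call (placeholder values): the returned AWSPreparedRequest
    # carries the SigV4 "Authorization" header computed from `credentials`, and
    # can be replayed with httpx, e.g.
    #
    #   prepped = llm.get_request_headers(
    #       credentials=creds,                      # botocore Credentials
    #       aws_region_name="us-west-2",
    #       extra_headers=None,
    #       endpoint_url="https://bedrock-runtime.us-west-2.amazonaws.com/model/my-model/converse",
    #       data=json.dumps({"messages": []}),
    #       headers={"Content-Type": "application/json"},
    #   )
    #   # httpx.post(prepped.url, headers=dict(prepped.headers), data=prepped.body)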
|
||||
@@ -0,0 +1,2 @@
from .converse_handler import BedrockConverseLLM
from .invoke_handler import BedrockLLM
@@ -0,0 +1,466 @@
import json
import urllib
from typing import Any, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper

from ..base_aws_llm import BaseAWSLLM, Credentials
from ..common_utils import BedrockError
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call


def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj: LiteLLMLoggingObject,
|
||||
json_mode: Optional[bool] = False,
|
||||
fake_stream: bool = False,
|
||||
):
|
||||
if client is None:
|
||||
client = _get_httpx_client() # Create a new client if none provided
|
||||
|
||||
response = client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
stream=not fake_stream,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(
|
||||
status_code=response.status_code, message=str(response.read())
|
||||
)
|
||||
|
||||
if fake_stream:
|
||||
model_response: (
|
||||
ModelResponse
|
||||
) = litellm.AmazonConverseConfig()._transform_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=litellm.ModelResponse(),
|
||||
stream=True,
|
||||
logging_obj=logging_obj,
|
||||
optional_params={},
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
encoding=litellm.encoding,
|
||||
) # type: ignore
|
||||
completion_stream: Any = MockResponseIterator(
|
||||
model_response=model_response, json_mode=json_mode
|
||||
)
|
||||
else:
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
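# Note (editorial sketch): with fake_stream=True the full Converse response is
# fetched up front and replayed through MockResponseIterator so callers can
# still iterate over "chunks"; otherwise AWSEventStreamDecoder parses the real
# event stream returned by the converse-stream endpoint.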
|
||||
|
||||
|
||||
class BedrockConverseLLM(BaseAWSLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def encode_model_id(self, model_id: str) -> str:
|
||||
"""
|
||||
Double encode the model ID to ensure it matches the expected double-encoded format.
|
||||
Args:
|
||||
model_id (str): The model ID to encode.
|
||||
Returns:
|
||||
str: The double-encoded model ID.
|
||||
"""
|
||||
return urllib.parse.quote(model_id, safe="") # type: ignore
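    # Encoding sketch, reusing the imported-model ARN from the docstrings in
    # base_aws_llm.py: with safe="" every ":" and "/" is percent-encoded, so
    #   "arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    # becomes
    #   "arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n"
    # which is the form used when building the /model/{modelId}/converse path below.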
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
credentials: Credentials,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
fake_stream: bool = False,
|
||||
json_mode: Optional[bool] = False,
|
||||
) -> CustomStreamWrapper:
|
||||
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(request_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||
extra_headers=headers,
|
||||
endpoint_url=api_base,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": dict(prepped.headers),
|
||||
},
|
||||
)
|
||||
|
||||
completion_stream = await make_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=dict(prepped.headers),
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
fake_stream=fake_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj: LiteLLMLoggingObject,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
credentials: Credentials,
|
||||
logger_fn=None,
|
||||
headers: dict = {},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(request_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||
extra_headers=headers,
|
||||
endpoint_url=api_base,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
|
||||
headers = dict(prepped.headers)
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = get_async_httpx_client(
|
||||
params=_params, llm_provider=litellm.LlmProviders.BEDROCK
|
||||
)
|
||||
else:
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
logging_obj=logging_obj,
|
||||
) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return litellm.AmazonConverseConfig()._transform_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream if isinstance(stream, bool) else False,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion( # noqa: PLR0915
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: Optional[str],
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
encoding,
|
||||
logging_obj: LiteLLMLoggingObject,
|
||||
optional_params: dict,
|
||||
acompletion: bool,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params: dict,
|
||||
logger_fn=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
):
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
unencoded_model_id = optional_params.pop("model_id", None)
|
||||
fake_stream = optional_params.pop("fake_stream", False)
|
||||
json_mode = optional_params.get("json_mode", False)
|
||||
if unencoded_model_id is not None:
|
||||
modelId = self.encode_model_id(model_id=unencoded_model_id)
|
||||
else:
|
||||
modelId = self.encode_model_id(model_id=model)
|
||||
|
||||
if stream is True and "ai21" in modelId:
|
||||
fake_stream = True
|
||||
|
||||
### SET REGION NAME ###
|
||||
aws_region_name = self._get_aws_region_name(
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
model_id=unencoded_model_id,
|
||||
)
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
optional_params.pop("aws_region_name", None)
|
||||
|
||||
litellm_params[
|
||||
"aws_region_name"
|
||||
] = aws_region_name # [DO NOT DELETE] important for async calls
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
if (stream is not None and stream is True) and not fake_stream:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
|
||||
|
||||
## COMPLETION CALL
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream is True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
json_mode=json_mode,
|
||||
fake_stream=fake_stream,
|
||||
credentials=credentials,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream, # type: ignore
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
credentials=credentials,
|
||||
) # type: ignore
|
||||
|
||||
## TRANSFORMATION ##
|
||||
|
||||
_data = litellm.AmazonConverseConfig()._transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
extra_headers=extra_headers,
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": proxy_endpoint_url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = _get_httpx_client(_params) # type: ignore
|
||||
else:
|
||||
client = client
|
||||
|
||||
if stream is not None and stream is True:
|
||||
completion_stream = make_sync_call(
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, HTTPHandler)
|
||||
else None
|
||||
),
|
||||
api_base=proxy_endpoint_url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
json_mode=json_mode,
|
||||
fake_stream=fake_stream,
|
||||
)
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
return streaming_response
|
||||
|
||||
### COMPLETION
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
url=proxy_endpoint_url,
|
||||
headers=prepped.headers,
|
||||
data=data,
|
||||
logging_obj=logging_obj,
|
||||
) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return litellm.AmazonConverseConfig()._transform_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream if isinstance(stream, bool) else False,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
@@ -0,0 +1,5 @@
"""
Uses base_llm_http_handler to call the 'converse like' endpoint.

Relevant issue: https://github.com/BerriAI/litellm/issues/8085
"""
@@ -0,0 +1,3 @@
"""
Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
"""
@@ -0,0 +1,852 @@
"""
Translating between OpenAI's `/chat/completion` format and Amazon's `/converse` format
"""

import copy
import time
import types
from typing import List, Literal, Optional, Tuple, Union, cast, overload

import httpx

import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.prompt_templates.factory import (
    BedrockConverseMessagesProcessor,
    _bedrock_converse_messages_pt,
    _bedrock_tools_pt,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.bedrock import *
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionRedactedThinkingBlock,
    ChatCompletionResponseMessage,
    ChatCompletionSystemMessage,
    ChatCompletionThinkingBlock,
    ChatCompletionToolCallChunk,
    ChatCompletionToolCallFunctionChunk,
    ChatCompletionToolParam,
    ChatCompletionToolParamFunctionChunk,
    ChatCompletionUserMessage,
    OpenAIChatCompletionToolParam,
    OpenAIMessageContentListBlock,
)
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
from litellm.utils import add_dummy_tool, has_tool_call_blocks

from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name


class AmazonConverseConfig(BaseConfig):
|
||||
"""
|
||||
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
||||
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
|
||||
"""
|
||||
|
||||
maxTokens: Optional[int]
|
||||
stopSequences: Optional[List[str]]
|
||||
temperature: Optional[int]
|
||||
topP: Optional[int]
|
||||
topK: Optional[int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
stopSequences: Optional[List[str]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
topP: Optional[int] = None,
|
||||
topK: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return "bedrock_converse"
|
||||
|
||||
@classmethod
|
||||
def get_config_blocks(cls) -> dict:
|
||||
return {
|
||||
"guardrailConfig": GuardrailConfigBlock,
|
||||
"performanceConfig": PerformanceConfigBlock,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
supported_params = [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
## Filter out 'cross-region' from model name
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
|
||||
if (
|
||||
base_model.startswith("anthropic")
|
||||
or base_model.startswith("mistral")
|
||||
or base_model.startswith("cohere")
|
||||
or base_model.startswith("meta.llama3-1")
|
||||
or base_model.startswith("meta.llama3-2")
|
||||
or base_model.startswith("meta.llama3-3")
|
||||
or base_model.startswith("amazon.nova")
|
||||
):
|
||||
supported_params.append("tools")
|
||||
|
||||
if litellm.utils.supports_tool_choice(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
supported_params.append("tool_choice")
|
||||
|
||||
if (
|
||||
"claude-3-7" in model
|
||||
): # [TODO]: move to a 'supports_reasoning_content' param from model cost map
|
||||
supported_params.append("thinking")
|
||||
supported_params.append("reasoning_effort")
|
||||
return supported_params
|
||||
|
||||
def map_tool_choice_values(
|
||||
self, model: str, tool_choice: Union[str, dict], drop_params: bool
|
||||
) -> Optional[ToolChoiceValuesBlock]:
|
||||
if tool_choice == "none":
|
||||
if litellm.drop_params is True or drop_params is True:
|
||||
return None
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
elif tool_choice == "required":
|
||||
return ToolChoiceValuesBlock(any={})
|
||||
elif tool_choice == "auto":
|
||||
return ToolChoiceValuesBlock(auto={})
|
||||
elif isinstance(tool_choice, dict):
|
||||
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
specific_tool = SpecificToolChoiceBlock(
|
||||
name=tool_choice.get("function", {}).get("name", "")
|
||||
)
|
||||
return ToolChoiceValuesBlock(tool=specific_tool)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
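    # Mapping sketch: OpenAI-style tool_choice values translate to Bedrock
    # ToolChoice blocks roughly as
    #   "auto"     -> {"auto": {}}
    #   "required" -> {"any": {}}
    #   {"type": "function", "function": {"name": "get_weather"}}  (hypothetical tool)
    #              -> {"tool": {"name": "get_weather"}}
    # while "none" is either dropped (when drop_params is set) or rejected with
    # a 400, since the Converse API has no equivalent.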
|
||||
|
||||
def get_supported_image_types(self) -> List[str]:
|
||||
return ["png", "jpeg", "gif", "webp"]
|
||||
|
||||
def get_supported_document_types(self) -> List[str]:
|
||||
return ["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
|
||||
|
||||
def get_all_supported_content_types(self) -> List[str]:
|
||||
return self.get_supported_image_types() + self.get_supported_document_types()
|
||||
|
||||
def _create_json_tool_call_for_response_format(
|
||||
self,
|
||||
json_schema: Optional[dict] = None,
|
||||
schema_name: str = "json_tool_call",
|
||||
description: Optional[str] = None,
|
||||
) -> ChatCompletionToolParam:
|
||||
"""
|
||||
Handles creating a tool call for getting responses in JSON format.
|
||||
|
||||
Args:
|
||||
json_schema (Optional[dict]): The JSON schema the response should be in
|
||||
|
||||
Returns:
|
||||
AnthropicMessagesTool: The tool call to send to Anthropic API to get responses in JSON format
|
||||
"""
|
||||
|
||||
if json_schema is None:
|
||||
# Anthropic raises a 400 BadRequest error if properties is passed as None
|
||||
# see usage with additionalProperties (Example 5) https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/extracting_structured_json.ipynb
|
||||
_input_schema = {
|
||||
"type": "object",
|
||||
"additionalProperties": True,
|
||||
"properties": {},
|
||||
}
|
||||
else:
|
||||
_input_schema = json_schema
|
||||
|
||||
tool_param_function_chunk = ChatCompletionToolParamFunctionChunk(
|
||||
name=schema_name, parameters=_input_schema
|
||||
)
|
||||
if description:
|
||||
tool_param_function_chunk["description"] = description
|
||||
|
||||
_tool = ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=tool_param_function_chunk,
|
||||
)
|
||||
return _tool
|
||||
|
||||
def _apply_tool_call_transformation(
|
||||
self,
|
||||
tools: List[OpenAIChatCompletionToolParam],
|
||||
model: str,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=tools
|
||||
)
|
||||
|
||||
if (
|
||||
"meta.llama3-3-70b-instruct-v1:0" in model
|
||||
and non_default_params.get("stream", False) is True
|
||||
):
|
||||
optional_params["fake_stream"] = True
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
|
||||
|
||||
for param, value in non_default_params.items():
|
||||
if param == "response_format" and isinstance(value, dict):
|
||||
ignore_response_format_types = ["text"]
|
||||
if value["type"] in ignore_response_format_types: # value is a no-op
|
||||
continue
|
||||
|
||||
json_schema: Optional[dict] = None
|
||||
schema_name: str = ""
|
||||
description: Optional[str] = None
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
schema_name = "json_tool_call"
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
schema_name = value["json_schema"]["name"]
|
||||
description = value["json_schema"].get("description")
|
||||
|
||||
if "type" in value and value["type"] == "text":
|
||||
continue
|
||||
|
||||
"""
|
||||
Follow similar approach to anthropic - translate to a single tool call.
|
||||
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
_tool = self._create_json_tool_call_for_response_format(
|
||||
json_schema=json_schema,
|
||||
schema_name=schema_name if schema_name != "" else "json_tool_call",
|
||||
description=description,
|
||||
)
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=[_tool]
|
||||
)
|
||||
if (
|
||||
litellm.utils.supports_tool_choice(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
)
|
||||
and not is_thinking_enabled
|
||||
):
|
||||
optional_params["tool_choice"] = ToolChoiceValuesBlock(
|
||||
tool=SpecificToolChoiceBlock(
|
||||
name=schema_name if schema_name != "" else "json_tool_call"
|
||||
)
|
||||
)
|
||||
optional_params["json_mode"] = True
|
||||
if non_default_params.get("stream", False) is True:
|
||||
optional_params["fake_stream"] = True
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["maxTokens"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
if len(value) == 0: # converse raises error for empty strings
|
||||
continue
|
||||
value = [value]
|
||||
optional_params["stopSequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["topP"] = value
|
||||
if param == "tools" and isinstance(value, list):
|
||||
self._apply_tool_call_transformation(
|
||||
tools=cast(List[OpenAIChatCompletionToolParam], value),
|
||||
model=model,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
if param == "tool_choice":
|
||||
_tool_choice_value = self.map_tool_choice_values(
|
||||
model=model, tool_choice=value, drop_params=drop_params # type: ignore
|
||||
)
|
||||
if _tool_choice_value is not None:
|
||||
optional_params["tool_choice"] = _tool_choice_value
|
||||
if param == "thinking":
|
||||
optional_params["thinking"] = value
|
||||
elif param == "reasoning_effort" and isinstance(value, str):
|
||||
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
|
||||
value
|
||||
)
|
||||
|
||||
self.update_optional_params_with_thinking_tokens(
|
||||
non_default_params=non_default_params, optional_params=optional_params
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
@overload
|
||||
def _get_cache_point_block(
|
||||
self,
|
||||
message_block: Union[
|
||||
OpenAIMessageContentListBlock,
|
||||
ChatCompletionUserMessage,
|
||||
ChatCompletionSystemMessage,
|
||||
],
|
||||
block_type: Literal["system"],
|
||||
) -> Optional[SystemContentBlock]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def _get_cache_point_block(
|
||||
self,
|
||||
message_block: Union[
|
||||
OpenAIMessageContentListBlock,
|
||||
ChatCompletionUserMessage,
|
||||
ChatCompletionSystemMessage,
|
||||
],
|
||||
block_type: Literal["content_block"],
|
||||
) -> Optional[ContentBlock]:
|
||||
pass
|
||||
|
||||
def _get_cache_point_block(
|
||||
self,
|
||||
message_block: Union[
|
||||
OpenAIMessageContentListBlock,
|
||||
ChatCompletionUserMessage,
|
||||
ChatCompletionSystemMessage,
|
||||
],
|
||||
block_type: Literal["system", "content_block"],
|
||||
) -> Optional[Union[SystemContentBlock, ContentBlock]]:
|
||||
if message_block.get("cache_control", None) is None:
|
||||
return None
|
||||
if block_type == "system":
|
||||
return SystemContentBlock(cachePoint=CachePointBlock(type="default"))
|
||||
else:
|
||||
return ContentBlock(cachePoint=CachePointBlock(type="default"))
|
||||
|
||||
def _transform_system_message(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> Tuple[List[AllMessageValues], List[SystemContentBlock]]:
|
||||
system_prompt_indices = []
|
||||
system_content_blocks: List[SystemContentBlock] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
system_prompt_indices.append(idx)
|
||||
if isinstance(message["content"], str) and message["content"]:
|
||||
system_content_blocks.append(
|
||||
SystemContentBlock(text=message["content"])
|
||||
)
|
||||
cache_block = self._get_cache_point_block(
|
||||
message, block_type="system"
|
||||
)
|
||||
if cache_block:
|
||||
system_content_blocks.append(cache_block)
|
||||
elif isinstance(message["content"], list):
|
||||
for m in message["content"]:
|
||||
if m.get("type") == "text" and m.get("text"):
|
||||
system_content_blocks.append(
|
||||
SystemContentBlock(text=m["text"])
|
||||
)
|
||||
cache_block = self._get_cache_point_block(
|
||||
m, block_type="system"
|
||||
)
|
||||
if cache_block:
|
||||
system_content_blocks.append(cache_block)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
return messages, system_content_blocks
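    # Behaviour sketch (hypothetical messages): system messages are lifted out of
    # the chat history into Bedrock's top-level system blocks, so
    #   [{"role": "system", "content": "You are terse."},
    #    {"role": "user", "content": "hi"}]
    # comes back as messages=[{"role": "user", "content": "hi"}] plus
    # system_content_blocks=[{"text": "You are terse."}], with an extra
    # cachePoint block appended wherever cache_control was set.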
|
||||
|
||||
def _transform_inference_params(self, inference_params: dict) -> InferenceConfig:
|
||||
if "top_k" in inference_params:
|
||||
inference_params["topK"] = inference_params.pop("top_k")
|
||||
return InferenceConfig(**inference_params)
|
||||
|
||||
def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
|
||||
val_top_k = None
|
||||
if "topK" in inference_params:
|
||||
val_top_k = inference_params.pop("topK")
|
||||
elif "top_k" in inference_params:
|
||||
val_top_k = inference_params.pop("top_k")
|
||||
|
||||
if val_top_k:
|
||||
if base_model.startswith("anthropic"):
|
||||
return {"top_k": val_top_k}
|
||||
if base_model.startswith("amazon.nova"):
|
||||
return {"inferenceConfig": {"topK": val_top_k}}
|
||||
|
||||
return {}
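    # Routing sketch: top-k is forwarded through additionalModelRequestFields in a
    # model-family-specific shape (40 is just an example value):
    #   anthropic.*    -> {"top_k": 40}
    #   amazon.nova-*  -> {"inferenceConfig": {"topK": 40}}
    #   anything else  -> {} (the value is dropped)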
|
||||
|
||||
def _transform_request_helper(
|
||||
self,
|
||||
model: str,
|
||||
system_content_blocks: List[SystemContentBlock],
|
||||
optional_params: dict,
|
||||
messages: Optional[List[AllMessageValues]] = None,
|
||||
) -> CommonRequestObject:
|
||||
## VALIDATE REQUEST
|
||||
"""
|
||||
Bedrock doesn't support tool calling without `tools=` param specified.
|
||||
"""
|
||||
if (
|
||||
"tools" not in optional_params
|
||||
and messages is not None
|
||||
and has_tool_call_blocks(messages)
|
||||
):
|
||||
if litellm.modify_params:
|
||||
optional_params["tools"] = add_dummy_tool(
|
||||
custom_llm_provider="bedrock_converse"
|
||||
)
|
||||
else:
|
||||
raise litellm.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
|
||||
model="",
|
||||
llm_provider="bedrock",
|
||||
)
|
||||
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
supported_converse_params = list(
|
||||
AmazonConverseConfig.__annotations__.keys()
|
||||
) + ["top_k"]
|
||||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
supported_config_params = list(self.get_config_blocks().keys())
|
||||
total_supported_params = (
|
||||
supported_converse_params
|
||||
+ supported_tool_call_params
|
||||
+ supported_config_params
|
||||
)
|
||||
inference_params.pop("json_mode", None) # used for handling json_schema
|
||||
|
||||
# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
|
||||
additional_request_params = {
|
||||
k: v for k, v in inference_params.items() if k not in total_supported_params
|
||||
}
|
||||
inference_params = {
|
||||
k: v for k, v in inference_params.items() if k in total_supported_params
|
||||
}
|
||||
|
||||
# Only set the topK value in for models that support it
|
||||
additional_request_params.update(
|
||||
self._handle_top_k_value(model, inference_params)
|
||||
)
|
||||
|
||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||
inference_params.pop("tools", [])
|
||||
)
|
||||
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
||||
if len(bedrock_tools) > 0:
|
||||
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
|
||||
"tool_choice", None
|
||||
)
|
||||
bedrock_tool_config = ToolConfigBlock(
|
||||
tools=bedrock_tools,
|
||||
)
|
||||
if tool_choice_values is not None:
|
||||
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||
|
||||
data: CommonRequestObject = {
|
||||
"additionalModelRequestFields": additional_request_params,
|
||||
"system": system_content_blocks,
|
||||
"inferenceConfig": self._transform_inference_params(
|
||||
inference_params=inference_params
|
||||
),
|
||||
}
|
||||
|
||||
# Handle all config blocks
|
||||
for config_name, config_class in self.get_config_blocks().items():
|
||||
config_value = inference_params.pop(config_name, None)
|
||||
if config_value is not None:
|
||||
data[config_name] = config_class(**config_value) # type: ignore
|
||||
|
||||
# Tool Config
|
||||
if bedrock_tool_config is not None:
|
||||
data["toolConfig"] = bedrock_tool_config
|
||||
|
||||
return data
|
||||
|
||||
async def _async_transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> RequestObject:
|
||||
messages, system_content_blocks = self._transform_system_message(messages)
|
||||
## TRANSFORMATION ##
|
||||
|
||||
_data: CommonRequestObject = self._transform_request_helper(
|
||||
model=model,
|
||||
system_content_blocks=system_content_blocks,
|
||||
optional_params=optional_params,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
bedrock_messages = (
|
||||
await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
)
|
||||
|
||||
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||
|
||||
return data
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
return cast(
|
||||
dict,
|
||||
self._transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
),
|
||||
)
|
||||
|
||||
def _transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> RequestObject:
|
||||
messages, system_content_blocks = self._transform_system_message(messages)
|
||||
|
||||
_data: CommonRequestObject = self._transform_request_helper(
|
||||
model=model,
|
||||
system_content_blocks=system_content_blocks,
|
||||
optional_params=optional_params,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
## TRANSFORMATION ##
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
|
||||
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||
|
||||
return data
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: Logging,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
return self._transform_response(
|
||||
model=model,
|
||||
response=raw_response,
|
||||
model_response=model_response,
|
||||
stream=optional_params.get("stream", False),
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
data=request_data,
|
||||
messages=messages,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def _transform_reasoning_content(
|
||||
self, reasoning_content_blocks: List[BedrockConverseReasoningContentBlock]
|
||||
) -> str:
|
||||
"""
|
||||
Extract the reasoning text from the reasoning content blocks
|
||||
|
||||
Ensures deepseek reasoning content compatible output.
|
||||
"""
|
||||
reasoning_content_str = ""
|
||||
for block in reasoning_content_blocks:
|
||||
if "reasoningText" in block:
|
||||
reasoning_content_str += block["reasoningText"]["text"]
|
||||
return reasoning_content_str
|
||||
|
||||
def _transform_thinking_blocks(
|
||||
self, thinking_blocks: List[BedrockConverseReasoningContentBlock]
|
||||
) -> List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]]:
|
||||
"""Return a consistent format for thinking blocks between Anthropic and Bedrock."""
|
||||
thinking_blocks_list: List[
|
||||
Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
|
||||
] = []
|
||||
for block in thinking_blocks:
|
||||
if "reasoningText" in block:
|
||||
_thinking_block = ChatCompletionThinkingBlock(type="thinking")
|
||||
_text = block["reasoningText"].get("text")
|
||||
_signature = block["reasoningText"].get("signature")
|
||||
if _text is not None:
|
||||
_thinking_block["thinking"] = _text
|
||||
if _signature is not None:
|
||||
_thinking_block["signature"] = _signature
|
||||
thinking_blocks_list.append(_thinking_block)
|
||||
elif "redactedContent" in block:
|
||||
_redacted_block = ChatCompletionRedactedThinkingBlock(
|
||||
type="redacted_thinking", data=block["redactedContent"]
|
||||
)
|
||||
thinking_blocks_list.append(_redacted_block)
|
||||
return thinking_blocks_list

    def _transform_usage(self, usage: ConverseTokenUsageBlock) -> Usage:
        input_tokens = usage["inputTokens"]
        output_tokens = usage["outputTokens"]
        total_tokens = usage["totalTokens"]
        cache_creation_input_tokens: int = 0
        cache_read_input_tokens: int = 0

        if "cacheReadInputTokens" in usage:
            cache_read_input_tokens = usage["cacheReadInputTokens"]
            input_tokens += cache_read_input_tokens
        if "cacheWriteInputTokens" in usage:
            # Do not increment prompt_tokens with cacheWriteInputTokens
            cache_creation_input_tokens = usage["cacheWriteInputTokens"]

        prompt_tokens_details = PromptTokensDetailsWrapper(
            cached_tokens=cache_read_input_tokens
        )
        openai_usage = Usage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=total_tokens,
            prompt_tokens_details=prompt_tokens_details,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
        )
        return openai_usage
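    # Worked example of the mapping above (token counts are illustrative):
    #   usage = {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15,
    #            "cacheReadInputTokens": 4}
    #   -> Usage(prompt_tokens=14, completion_tokens=5, total_tokens=15,
    #            cache_read_input_tokens=4, cache_creation_input_tokens=0)
    # cacheWriteInputTokens, when present, fills cache_creation_input_tokens but is
    # not added to prompt_tokens.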
|
||||
|
||||
def _transform_response(
|
||||
self,
|
||||
model: str,
|
||||
response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Optional[Logging],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str],
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
if logging_obj is not None:
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
json_mode: Optional[bool] = optional_params.pop("json_mode", None)
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||
response.text, str(e)
|
||||
),
|
||||
status_code=422,
|
||||
)
|
||||
|
||||
"""
|
||||
Bedrock Response Object has optional message block
|
||||
|
||||
completion_response["output"].get("message", None)
|
||||
|
||||
A message block looks like this (Example 1):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
(Example 2):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
|
||||
"name": "top_song",
|
||||
"input": {
|
||||
"sign": "WZPZ"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
message: Optional[MessageBlock] = completion_response["output"]["message"]
|
||||
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||
content_str = ""
|
||||
tools: List[ChatCompletionToolCallChunk] = []
|
||||
reasoningContentBlocks: Optional[
|
||||
List[BedrockConverseReasoningContentBlock]
|
||||
] = None
|
||||
|
||||
if message is not None:
|
||||
for idx, content in enumerate(message["content"]):
|
||||
"""
|
||||
- Content is either a tool response or text
|
||||
"""
|
||||
if "text" in content:
|
||||
content_str += content["text"]
|
||||
if "toolUse" in content:
|
||||
## check tool name was formatted by litellm
|
||||
_response_tool_name = content["toolUse"]["name"]
|
||||
response_tool_name = get_bedrock_tool_name(
|
||||
response_tool_name=_response_tool_name
|
||||
)
|
||||
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||
name=response_tool_name,
|
||||
arguments=json.dumps(content["toolUse"]["input"]),
|
||||
)
|
||||
|
||||
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||
id=content["toolUse"]["toolUseId"],
|
||||
type="function",
|
||||
function=_function_chunk,
|
||||
index=idx,
|
||||
)
|
||||
tools.append(_tool_response_chunk)
|
||||
if "reasoningContent" in content:
|
||||
if reasoningContentBlocks is None:
|
||||
reasoningContentBlocks = []
|
||||
reasoningContentBlocks.append(content["reasoningContent"])
|
||||
|
||||
if reasoningContentBlocks is not None:
|
||||
chat_completion_message["provider_specific_fields"] = {
|
||||
"reasoningContentBlocks": reasoningContentBlocks,
|
||||
}
|
||||
chat_completion_message[
|
||||
"reasoning_content"
|
||||
] = self._transform_reasoning_content(reasoningContentBlocks)
|
||||
chat_completion_message[
|
||||
"thinking_blocks"
|
||||
] = self._transform_thinking_blocks(reasoningContentBlocks)
|
||||
chat_completion_message["content"] = content_str
|
||||
if json_mode is True and tools is not None and len(tools) == 1:
|
||||
# to support 'json_schema' logic on bedrock models
|
||||
json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
|
||||
if json_mode_content_str is not None:
|
||||
chat_completion_message["content"] = json_mode_content_str
|
||||
else:
|
||||
chat_completion_message["tool_calls"] = tools
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
usage = self._transform_usage(completion_response["usage"])
|
||||
|
||||
model_response.choices = [
|
||||
litellm.Choices(
|
||||
finish_reason=map_finish_reason(completion_response["stopReason"]),
|
||||
index=0,
|
||||
message=litellm.Message(**chat_completion_message),
|
||||
)
|
||||
]
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
# Add "trace" from Bedrock guardrails - if user has opted in to returning it
|
||||
if "trace" in completion_response:
|
||||
setattr(model_response, "trace", completion_response["trace"])
|
||||
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return BedrockError(
|
||||
message=error_message,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
File diff suppressed because it is too large
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,99 @@
import types
from typing import List, Optional

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)


class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported Params for the Amazon / AI21 models:

    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.

    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.

    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.

    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.

    - `frequencyPenalty` (object): Placeholder for frequency penalty object.

    - `presencePenalty` (object): Placeholder for presence penalty object.

    - `countPenalty` (object): Placeholder for count penalty object.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["maxTokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["topP"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params
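
# Illustrative usage sketch (model id and values are examples, not part of the
# original module): map_openai_params renames OpenAI-style keys to the AI21
# invoke format.
#   AmazonAI21Config().map_openai_params(
#       non_default_params={"max_tokens": 256, "top_p": 0.9},
#       optional_params={}, model="ai21.j2-ultra-v1", drop_params=False,
#   )
#   -> {"maxTokens": 256, "topP": 0.9}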
@@ -0,0 +1,75 @@
import types
from typing import List, Optional

from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)
from litellm.llms.cohere.chat.transformation import CohereChatConfig


class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported Params for the Amazon / Cohere models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a
    """

    max_tokens: Optional[int] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        supported_params = CohereChatConfig.get_supported_openai_params(
            self, model=model
        )
        return supported_params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return CohereChatConfig.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
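
# Illustrative sketch (model id and values are examples): this config defers both
# param discovery and mapping to CohereChatConfig, so an OpenAI-style call maps
# exactly as it would on the native Cohere chat route, e.g.
#   AmazonCohereConfig().map_openai_params(
#       non_default_params={"max_tokens": 100, "temperature": 0.2},
#       optional_params={}, model="cohere.command-r-v1:0", drop_params=False,
#   )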
@@ -0,0 +1,135 @@
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from httpx import Response
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_parse_content_for_reasoning,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.bedrock import AmazonDeepSeekR1StreamingResponse
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionUsageBlock,
|
||||
Choices,
|
||||
Delta,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
from .amazon_llama_transformation import AmazonLlamaConfig
|
||||
|
||||
|
||||
class AmazonDeepSeekR1Config(AmazonLlamaConfig):
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Extract the reasoning content, and return it as a separate field in the response.
|
||||
"""
|
||||
response = super().transform_response(
|
||||
model,
|
||||
raw_response,
|
||||
model_response,
|
||||
logging_obj,
|
||||
request_data,
|
||||
messages,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
encoding,
|
||||
api_key,
|
||||
json_mode,
|
||||
)
|
||||
prompt = cast(Optional[str], request_data.get("prompt"))
|
||||
message_content = cast(
|
||||
Optional[str], cast(Choices, response.choices[0]).message.get("content")
|
||||
)
|
||||
if prompt and prompt.strip().endswith("<think>") and message_content:
|
||||
message_content_with_reasoning_token = "<think>" + message_content
|
||||
reasoning, content = _parse_content_for_reasoning(
|
||||
message_content_with_reasoning_token
|
||||
)
|
||||
provider_specific_fields = (
|
||||
cast(Choices, response.choices[0]).message.provider_specific_fields
|
||||
or {}
|
||||
)
|
||||
if reasoning:
|
||||
provider_specific_fields["reasoning_content"] = reasoning
|
||||
|
||||
message = Message(
|
||||
**{
|
||||
**cast(Choices, response.choices[0]).message.model_dump(),
|
||||
"content": content,
|
||||
"provider_specific_fields": provider_specific_fields,
|
||||
}
|
||||
)
|
||||
cast(Choices, response.choices[0]).message = message
|
||||
return response
|
||||
|
||||
|
||||
class AmazonDeepseekR1ResponseIterator(BaseModelResponseIterator):
|
||||
def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
|
||||
super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)
|
||||
self.has_finished_thinking = False
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||
"""
|
||||
Deepseek r1 starts by thinking, then it generates the response.
|
||||
"""
|
||||
try:
|
||||
typed_chunk = AmazonDeepSeekR1StreamingResponse(**chunk) # type: ignore
|
||||
generated_content = typed_chunk["generation"]
|
||||
if generated_content == "</think>" and not self.has_finished_thinking:
|
||||
verbose_logger.debug(
|
||||
"Deepseek r1: </think> received, setting has_finished_thinking to True"
|
||||
)
|
||||
generated_content = ""
|
||||
self.has_finished_thinking = True
|
||||
|
||||
prompt_token_count = typed_chunk.get("prompt_token_count") or 0
|
||||
generation_token_count = typed_chunk.get("generation_token_count") or 0
|
||||
usage = ChatCompletionUsageBlock(
|
||||
prompt_tokens=prompt_token_count,
|
||||
completion_tokens=generation_token_count,
|
||||
total_tokens=prompt_token_count + generation_token_count,
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=typed_chunk["stop_reason"],
|
||||
delta=Delta(
|
||||
content=(
|
||||
generated_content
|
||||
if self.has_finished_thinking
|
||||
else None
|
||||
),
|
||||
reasoning_content=(
|
||||
generated_content
|
||||
if not self.has_finished_thinking
|
||||
else None
|
||||
),
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -0,0 +1,80 @@
import types
from typing import List, Optional

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)


class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        max_gen_len: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_gen_len"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params
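
# Illustrative sketch (model id and values are examples): for a Meta Llama model,
# OpenAI-style params map onto the invoke body keys used above, e.g.
#   AmazonLlamaConfig().map_openai_params(
#       non_default_params={"max_tokens": 512, "top_p": 0.95},
#       optional_params={}, model="meta.llama3-8b-instruct-v1:0", drop_params=False,
#   )
#   -> {"max_gen_len": 512, "top_p": 0.95}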
@@ -0,0 +1,83 @@
import types
from typing import List, Optional

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)


class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
    Supported Params for the Amazon / Mistral models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
    - `top_k` (float) top k for model
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        return ["max_tokens", "temperature", "top_p", "stop", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_tokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stop":
                optional_params["stop"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params
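
# Illustrative sketch (model id and values are examples): Mistral invoke params
# keep their OpenAI-style names and are passed through into the invoke body, e.g.
#   AmazonMistralConfig().map_openai_params(
#       non_default_params={"max_tokens": 300, "stop": ["###"]},
#       optional_params={}, model="mistral.mistral-7b-instruct-v0:2", drop_params=False,
#   )
#   -> {"max_tokens": 300, "stop": ["###"]}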
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Handles transforming requests for `bedrock/invoke/{nova} models`
|
||||
|
||||
Inherits from `AmazonConverseConfig`
|
||||
|
||||
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..converse_transformation import AmazonConverseConfig
|
||||
from .base_invoke_transformation import AmazonInvokeConfig
|
||||
|
||||
|
||||
class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
|
||||
"""
|
||||
Config for sending `nova` requests to `/bedrock/invoke/`
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
return AmazonConverseConfig.get_supported_openai_params(self, model)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return AmazonConverseConfig.map_openai_params(
|
||||
self, non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
_transformed_nova_request = AmazonConverseConfig.transform_request(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
_bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
|
||||
**_transformed_nova_request
|
||||
)
|
||||
self._remove_empty_system_messages(_bedrock_invoke_nova_request)
|
||||
bedrock_invoke_nova_request = self._filter_allowed_fields(
|
||||
_bedrock_invoke_nova_request
|
||||
)
|
||||
return bedrock_invoke_nova_request
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: Logging,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> litellm.ModelResponse:
|
||||
return AmazonConverseConfig.transform_response(
|
||||
self,
|
||||
model,
|
||||
raw_response,
|
||||
model_response,
|
||||
logging_obj,
|
||||
request_data,
|
||||
messages,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
encoding,
|
||||
api_key,
|
||||
json_mode,
|
||||
)
|
||||
|
||||
def _filter_allowed_fields(
|
||||
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
|
||||
) -> dict:
|
||||
"""
|
||||
Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` dataclass.
|
||||
"""
|
||||
allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
|
||||
return {
|
||||
k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
|
||||
}
|
||||
|
||||
def _remove_empty_system_messages(
|
||||
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
|
||||
) -> None:
|
||||
"""
|
||||
In-place remove empty `system` messages from the request.
|
||||
|
||||
/bedrock/invoke/ does not allow empty `system` messages.
|
||||
"""
|
||||
_system_message = bedrock_invoke_nova_request.get("system", None)
|
||||
if isinstance(_system_message, list) and len(_system_message) == 0:
|
||||
bedrock_invoke_nova_request.pop("system", None)
|
||||
return
|
||||
@@ -0,0 +1,116 @@
|
||||
import re
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
|
||||
|
||||
Supported Params for the Amazon Titan models:
|
||||
|
||||
- `maxTokenCount` (integer) max tokens,
|
||||
- `stopSequences` (string[]) list of stop sequence strings
|
||||
- `temperature` (float) temperature for model,
|
||||
- `topP` (int) top p for model
|
||||
"""
|
||||
|
||||
maxTokenCount: Optional[int] = None
|
||||
stopSequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
AmazonInvokeConfig.__init__(self)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def _map_and_modify_arg(
|
||||
self,
|
||||
supported_params: dict,
|
||||
provider: str,
|
||||
model: str,
|
||||
stop: Union[List[str], str],
|
||||
):
|
||||
"""
|
||||
filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
|
||||
"""
|
||||
filtered_stop = None
|
||||
if "stop" in supported_params and litellm.drop_params:
|
||||
if provider == "bedrock" and "amazon" in model:
|
||||
filtered_stop = []
|
||||
if isinstance(stop, list):
|
||||
for s in stop:
|
||||
if re.match(r"^(\|+|User:)$", s):
|
||||
filtered_stop.append(s)
|
||||
if filtered_stop is not None:
|
||||
supported_params["stop"] = filtered_stop
|
||||
|
||||
return supported_params
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "max_tokens" or k == "max_completion_tokens":
|
||||
optional_params["maxTokenCount"] = v
|
||||
if k == "temperature":
|
||||
optional_params["temperature"] = v
|
||||
if k == "stop":
|
||||
filtered_stop = self._map_and_modify_arg(
|
||||
{"stop": v}, provider="bedrock", model=model, stop=v
|
||||
)
|
||||
optional_params["stopSequences"] = filtered_stop["stop"]
|
||||
if k == "top_p":
|
||||
optional_params["topP"] = v
|
||||
if k == "stream":
|
||||
optional_params["stream"] = v
|
||||
return optional_params
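    # Illustrative sketch of the Titan mapping above (model id and values are
    # examples). When the global litellm.drop_params is True, _map_and_modify_arg
    # keeps only stop strings matching r"^(\|+|User:)$" (the filter checks the
    # global flag, not the drop_params argument), so
    #   AmazonTitanConfig().map_openai_params(
    #       {"max_tokens": 128, "stop": ["User:", "###"]}, {},
    #       model="amazon.titan-text-express-v1", drop_params=True,
    #   )
    #   would yield {"maxTokenCount": 128, "stopSequences": ["User:"]},
    #   with the "###" entry dropped.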
|
||||
@@ -0,0 +1,90 @@
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
|
||||
from .base_invoke_transformation import AmazonInvokeConfig
|
||||
|
||||
|
||||
class AmazonAnthropicConfig(AmazonInvokeConfig):
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
|
||||
Supported Params for the Amazon / Anthropic models:
|
||||
|
||||
- `max_tokens_to_sample` (integer) max tokens,
|
||||
- `temperature` (float) model temperature,
|
||||
- `top_k` (integer) top k,
|
||||
- `top_p` (integer) top p,
|
||||
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
|
||||
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
"""
|
||||
|
||||
max_tokens_to_sample: Optional[int] = litellm.max_tokens
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
anthropic_version: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens_to_sample: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str):
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"temperature",
|
||||
"stop",
|
||||
"top_p",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["max_tokens_to_sample"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
return optional_params
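    # Illustrative sketch (model id and values are examples): for the legacy Claude
    # text-completion format, "max_tokens" becomes "max_tokens_to_sample" and
    # "stop" becomes "stop_sequences", e.g.
    #   AmazonAnthropicConfig().map_openai_params(
    #       {"max_tokens": 200, "stop": ["\n\nHuman:"]}, {},
    #       model="anthropic.claude-v2", drop_params=False,
    #   )
    #   -> {"max_tokens_to_sample": 200, "stop_sequences": ["\n\nHuman:"]}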
|
||||
@@ -0,0 +1,100 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonAnthropicClaude3Config(AmazonInvokeConfig, AnthropicConfig):
|
||||
"""
|
||||
Reference:
|
||||
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
|
||||
|
||||
Supported Params for the Amazon / Anthropic Claude 3 models:
|
||||
"""
|
||||
|
||||
anthropic_version: str = "bedrock-2023-05-31"
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return AnthropicConfig.get_supported_openai_params(self, model)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return AnthropicConfig.map_openai_params(
|
||||
self,
|
||||
non_default_params,
|
||||
optional_params,
|
||||
model,
|
||||
drop_params,
|
||||
)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
_anthropic_request = AnthropicConfig.transform_request(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
_anthropic_request.pop("model", None)
|
||||
_anthropic_request.pop("stream", None)
|
||||
if "anthropic_version" not in _anthropic_request:
|
||||
_anthropic_request["anthropic_version"] = self.anthropic_version
|
||||
|
||||
return _anthropic_request
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
return AnthropicConfig.transform_response(
|
||||
self,
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
@@ -0,0 +1,679 @@
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
cohere_message_pt,
|
||||
custom_prompt,
|
||||
deepseek_r1_pt,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.bedrock.chat.invoke_handler import make_call, make_sync_call
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
|
||||
|
||||
class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||
def __init__(self, **kwargs):
|
||||
BaseConfig.__init__(self, **kwargs)
|
||||
BaseAWSLLM.__init__(self, **kwargs)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
|
||||
"""
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
|
||||
"""
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
return optional_params
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete url for the request
|
||||
"""
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
modelId = self.get_bedrock_model_id(
|
||||
model=model,
|
||||
provider=provider,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
aws_bedrock_runtime_endpoint = optional_params.get(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=self._get_aws_region_name(
|
||||
optional_params=optional_params, model=model
|
||||
),
|
||||
)
|
||||
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
|
||||
proxy_endpoint_url = (
|
||||
f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
|
||||
)
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
|
||||
|
||||
return endpoint_url
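    # Illustrative sketch of the resulting URL (assuming the AWS region resolves to
    # us-west-2 and no custom runtime endpoint is configured; model id is an example):
    #   AmazonInvokeConfig().get_complete_url(
    #       api_base=None, api_key=None, model="anthropic.claude-v2",
    #       optional_params={}, litellm_params={}, stream=True,
    #   )
    #   -> "https://bedrock-runtime.us-west-2.amazonaws.com/model/anthropic.claude-v2/invoke-with-response-stream"
    # Non-streaming calls, and ai21 models (which are fake-streamed), resolve to the
    # plain ".../invoke" endpoint instead.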
|
||||
|
||||
def sign_request(
|
||||
self,
|
||||
headers: dict,
|
||||
optional_params: dict,
|
||||
request_data: dict,
|
||||
api_base: str,
|
||||
model: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
fake_stream: Optional[bool] = None,
|
||||
) -> dict:
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.get("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.get("aws_session_token", None)
|
||||
aws_role_name = optional_params.get("aws_role_name", None)
|
||||
aws_session_name = optional_params.get("aws_session_name", None)
|
||||
aws_profile_name = optional_params.get("aws_profile_name", None)
|
||||
aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
|
||||
aws_region_name = self._get_aws_region_name(
|
||||
optional_params=optional_params, model=model
|
||||
)
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
if headers is not None:
|
||||
headers = {"Content-Type": "application/json", **headers}
|
||||
else:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
request = AWSRequest(
|
||||
method="POST",
|
||||
url=api_base,
|
||||
data=json.dumps(request_data),
|
||||
headers=headers,
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
|
||||
request_headers_dict = dict(request.headers)
|
||||
if (
|
||||
headers is not None and "Authorization" in headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request_headers_dict["Authorization"] = headers["Authorization"]
|
||||
return request_headers_dict
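    # Minimal sketch of how the signed headers are consumed downstream (an assumed
    # flow for illustration, not part of this method): the SigV4-signed dict must be
    # sent with the exact body bytes that were signed, e.g.
    #   signed = cfg.sign_request(
    #       headers={}, optional_params={"aws_region_name": "us-east-1"},
    #       request_data=body, api_base=url,
    #   )
    #   httpx.post(url, content=json.dumps(body), headers=signed)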
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
|
||||
hf_model_name = litellm_params.get("hf_model_name", None)
|
||||
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
|
||||
prompt, chat_history = self.convert_messages_to_prompt(
|
||||
model=hf_model_name or model,
|
||||
messages=messages,
|
||||
provider=provider,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
)
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params = {
|
||||
k: v
|
||||
for k, v in inference_params.items()
|
||||
if k not in self.aws_authentication_params
|
||||
}
|
||||
request_data: dict = {}
|
||||
if provider == "cohere":
|
||||
if model.startswith("cohere.command-r"):
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonCohereChatConfig().get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
_data = {"message": prompt, **inference_params}
|
||||
if chat_history is not None:
|
||||
_data["chat_history"] = chat_history
|
||||
request_data = _data
|
||||
else:
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonCohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
if stream is True:
|
||||
inference_params[
|
||||
"stream"
|
||||
] = True # cohere requires stream = True in inference params
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
elif provider == "anthropic":
|
||||
return litellm.AmazonAnthropicClaude3Config().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
elif provider == "nova":
|
||||
return litellm.AmazonInvokeNovaConfig().transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
elif provider == "ai21":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonAI21Config.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
elif provider == "mistral":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonMistralConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
elif provider == "amazon": # amazon titan
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonTitanConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
request_data = {
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig": inference_params,
|
||||
}
|
||||
elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonLlamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
request_data = {"prompt": prompt, **inference_params}
|
||||
else:
|
||||
raise BedrockError(
|
||||
status_code=404,
|
||||
message="Bedrock Invoke HTTPX: Unknown provider={}, model={}. Try calling via converse route - `bedrock/converse/<model>`.".format(
|
||||
provider, model
|
||||
),
|
||||
)
|
||||
|
||||
return request_data
|
||||
|
||||
def transform_response( # noqa: PLR0915
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
completion_response = raw_response.json()
|
||||
except Exception:
|
||||
raise BedrockError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"bedrock invoke response % s",
|
||||
json.dumps(completion_response, indent=4, default=str),
|
||||
)
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
outputText: Optional[str] = None
|
||||
try:
|
||||
if provider == "cohere":
|
||||
if "text" in completion_response:
|
||||
outputText = completion_response["text"] # type: ignore
|
||||
elif "generations" in completion_response:
|
||||
outputText = completion_response["generations"][0]["text"]
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
completion_response["generations"][0]["finish_reason"]
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
return litellm.AmazonAnthropicClaude3Config().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
elif provider == "nova":
|
||||
return litellm.AmazonInvokeNovaConfig().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
elif provider == "ai21":
|
||||
outputText = (
|
||||
completion_response.get("completions")[0].get("data").get("text")
|
||||
)
|
||||
elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
|
||||
outputText = completion_response["generation"]
|
||||
elif provider == "mistral":
|
||||
outputText = completion_response["outputs"][0]["text"]
|
||||
model_response.choices[0].finish_reason = completion_response[
|
||||
"outputs"
|
||||
][0]["stop_reason"]
|
||||
else: # amazon titan
|
||||
outputText = completion_response.get("results")[0].get("outputText")
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message="Error processing={}, Received error={}".format(
|
||||
raw_response.text, str(e)
|
||||
),
|
||||
status_code=422,
|
||||
)
|
||||
|
||||
try:
|
||||
if (
|
||||
outputText is not None
|
||||
and len(outputText) > 0
|
||||
and hasattr(model_response.choices[0], "message")
|
||||
and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore
|
||||
is None
|
||||
):
|
||||
model_response.choices[0].message.content = outputText # type: ignore
|
||||
elif (
|
||||
hasattr(model_response.choices[0], "message")
|
||||
and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore
|
||||
is not None
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise Exception()
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message="Error parsing received text={}.\nError-{}".format(
|
||||
outputText, str(e)
|
||||
),
|
||||
status_code=raw_response.status_code,
|
||||
)
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
bedrock_input_tokens = raw_response.headers.get(
|
||||
"x-amzn-bedrock-input-token-count", None
|
||||
)
|
||||
bedrock_output_tokens = raw_response.headers.get(
|
||||
"x-amzn-bedrock-output-token-count", None
|
||||
)
|
||||
|
||||
prompt_tokens = int(
|
||||
bedrock_input_tokens or litellm.token_counter(messages=messages)
|
||||
)
|
||||
|
||||
completion_tokens = int(
|
||||
bedrock_output_tokens
|
||||
or litellm.token_counter(
|
||||
text=model_response.choices[0].message.content, # type: ignore
|
||||
count_response_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return BedrockError(status_code=status_code, message=error_message)

    @track_llm_api_timing()
    def get_async_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[AsyncHTTPHandler] = None,
        json_mode: Optional[bool] = None,
    ) -> CustomStreamWrapper:
        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                client=client,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                fake_stream=True if "ai21" in api_base else False,
                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
                json_mode=json_mode,
            ),
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response

    @track_llm_api_timing()
    def get_sync_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        json_mode: Optional[bool] = None,
    ) -> CustomStreamWrapper:
        if client is None or isinstance(client, AsyncHTTPHandler):
            client = _get_httpx_client(params={})
        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_sync_call,
                client=client,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                fake_stream=True if "ai21" in api_base else False,
                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
                json_mode=json_mode,
            ),
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response

    @property
    def has_custom_stream_wrapper(self) -> bool:
        return True

    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """
        Bedrock invoke does not allow passing `stream` in the request body.
        """
        return False

    @staticmethod
    def get_bedrock_invoke_provider(
        model: str,
    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the bedrock provider from the model

        handles 4 scenarios:
        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
        """
        if model.startswith("invoke/"):
            model = model.replace("invoke/", "", 1)

        _split_model = model.split(".")[0]
        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)

        # If not a known provider, check for pattern with two slashes
        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
        if provider is not None:
            return provider

        # check if provider == "nova"
        if "nova" in model:
            return "nova"

        for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            if provider in model:
                return provider
        return None

    @staticmethod
    def _get_provider_from_model_path(
        model_path: str,
    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the provider from a model path with format: provider/model-name

        Args:
            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')

        Returns:
            Optional[str]: The provider name, or None if no valid provider found
        """
        parts = model_path.split("/")
        if len(parts) >= 1:
            provider = parts[0]
            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
        return None

    def get_bedrock_model_id(
        self,
        optional_params: dict,
        provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL],
        model: str,
    ) -> str:
        modelId = optional_params.pop("model_id", None)
        if modelId is not None:
            modelId = self.encode_model_id(model_id=modelId)
        else:
            modelId = model

        modelId = modelId.replace("invoke/", "", 1)
        if provider == "llama" and "llama/" in modelId:
            modelId = self._get_model_id_from_model_with_spec(modelId, spec="llama")
        elif provider == "deepseek_r1" and "deepseek_r1/" in modelId:
            modelId = self._get_model_id_from_model_with_spec(
                modelId, spec="deepseek_r1"
            )
        return modelId

    def _get_model_id_from_model_with_spec(
        self,
        model: str,
        spec: str,
    ) -> str:
        """
        Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
        """
        model_id = model.replace(spec + "/", "")
        return self.encode_model_id(model_id=model_id)

    def encode_model_id(self, model_id: str) -> str:
        """
        Double encode the model ID to ensure it matches the expected double-encoded format.

        Args:
            model_id (str): The model ID to encode.

        Returns:
            str: The double-encoded model ID.
        """
        return urllib.parse.quote(model_id, safe="")

    def convert_messages_to_prompt(
        self, model, messages, provider, custom_prompt_dict
    ) -> Tuple[str, Optional[list]]:
        # handle anthropic prompts and amazon titan prompts
        prompt = ""
        chat_history: Optional[list] = None
        ## CUSTOM PROMPT
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
            return prompt, None
        ## ELSE
        if provider == "anthropic" or provider == "amazon":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "mistral":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "meta" or provider == "llama":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "cohere":
            prompt, chat_history = cohere_message_pt(messages=messages)
        elif provider == "deepseek_r1":
            prompt = deepseek_r1_pt(messages=messages)
        else:
            prompt = ""
            for message in messages:
                if "role" in message:
                    if message["role"] == "user":
                        prompt += f"{message['content']}"
                    else:
                        prompt += f"{message['content']}"
                else:
                    prompt += f"{message['content']}"
        return prompt, chat_history  # type: ignore
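A quick sketch of the four routing scenarios described in `get_bedrock_invoke_provider`'s docstring. The import path is illustrative and may differ between litellm versions; treat this as a usage sketch, not part of the committed file.

# Hypothetical usage sketch - import path is an assumption.
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)

get_provider = AmazonInvokeConfig.get_bedrock_invoke_provider
assert get_provider("invoke/anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert get_provider("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
assert get_provider("llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n") == "llama"
assert get_provider("us.amazon.nova-pro-v1:0") == "nova"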
@@ -0,0 +1,422 @@
"""
Common utilities used across bedrock chat/embedding/image generation
"""

import os
from typing import List, Literal, Optional, Union

import httpx

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret


class BedrockError(BaseLLMException):
    pass


class AmazonBedrockGlobalConfig:
    def __init__(self):
        pass

    def get_mapped_special_auth_params(self) -> dict:
        """
        Mapping of common auth params across bedrock/vertex/azure/watsonx
        """
        return {"region_name": "aws_region_name"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()
        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_all_regions(self) -> List[str]:
        return (
            self.get_us_regions()
            + self.get_eu_regions()
            + self.get_ap_regions()
            + self.get_ca_regions()
            + self.get_sa_regions()
        )

    def get_ap_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "ap-northeast-1",  # Asia Pacific (Tokyo)
            "ap-northeast-2",  # Asia Pacific (Seoul)
            "ap-northeast-3",  # Asia Pacific (Osaka)
            "ap-south-1",  # Asia Pacific (Mumbai)
            "ap-south-2",  # Asia Pacific (Hyderabad)
            "ap-southeast-1",  # Asia Pacific (Singapore)
            "ap-southeast-2",  # Asia Pacific (Sydney)
        ]

    def get_sa_regions(self) -> List[str]:
        return ["sa-east-1"]

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "eu-west-1",  # Europe (Ireland)
            "eu-west-2",  # Europe (London)
            "eu-west-3",  # Europe (Paris)
            "eu-central-1",  # Europe (Frankfurt)
            "eu-central-2",  # Europe (Zurich)
            "eu-south-1",  # Europe (Milan)
            "eu-south-2",  # Europe (Spain)
            "eu-north-1",  # Europe (Stockholm)
        ]

    def get_ca_regions(self) -> List[str]:
        return ["ca-central-1"]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "us-east-1",  # US East (N. Virginia)
            "us-east-2",  # US East (Ohio)
            "us-west-1",  # US West (N. California)
            "us-west-2",  # US West (Oregon)
            "us-gov-east-1",  # AWS GovCloud (US-East)
            "us-gov-west-1",  # AWS GovCloud (US-West)
        ]


def add_custom_header(headers):
    """Closure to capture the headers and add them."""

    def callback(request, **kwargs):
        """Actual callback function that Boto3 will call."""
        for header_name, header_value in headers.items():
            request.headers.add_header(header_name, header_value)

    return callback


def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)
    ## CHECK IS 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Iterate over parameters and update if needed
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore
    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
    ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)

    ### SET REGION NAME
    if region_name:
        pass
    elif aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    if aws_bedrock_runtime_endpoint:
        endpoint_url = aws_bedrock_runtime_endpoint
    elif env_aws_bedrock_runtime_endpoint:
        endpoint_url = env_aws_bedrock_runtime_endpoint
    else:
        endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

    import boto3

    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise BedrockError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials

        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="bedrock-runtime",
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables

        client = boto3.client(
            service_name="bedrock-runtime",
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    if extra_headers:
        client.meta.events.register(
            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
        )

    return client
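A minimal usage sketch for `init_bedrock_client`, assuming static credentials are passed in (values shown are placeholders, not real keys):

# Sketch: build a bedrock-runtime client with explicit keys and a custom timeout.
client = init_bedrock_client(
    aws_access_key_id="AKIA...",          # placeholder
    aws_secret_access_key="...",          # placeholder
    aws_region_name="us-west-2",
    timeout=30.0,                          # becomes connect/read timeout on the botocore Config
)
# Region resolution falls back to AWS_REGION_NAME, then AWS_REGION, and raises BedrockError(401) if none is set.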

class ModelResponseIterator:
    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response


def get_bedrock_tool_name(response_tool_name: str) -> str:
    """
    If litellm formatted the input tool name, we need to convert it back to the original name.

    Args:
        response_tool_name (str): The name of the tool as received from the response.

    Returns:
        str: The original name of the tool.
    """
    if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict:
        response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[
            response_tool_name
        ]
    return response_tool_name


class BedrockModelInfo(BaseLLMModelInfo):
    global_config = AmazonBedrockGlobalConfig()
    all_global_regions = global_config.get_all_regions()

    @staticmethod
    def extract_model_name_from_arn(model: str) -> str:
        """
        Extract the model name from an AWS Bedrock ARN.
        Returns the string after the last '/' if 'arn' is in the input string.

        Args:
            arn (str): The ARN string to parse

        Returns:
            str: The extracted model name if 'arn' is in the string,
                 otherwise returns the original string
        """
        if "arn" in model.lower():
            return model.split("/")[-1]
        return model

    @staticmethod
    def get_non_litellm_routing_model_name(model: str) -> str:
        if model.startswith("bedrock/"):
            model = model.split("/", 1)[1]

        if model.startswith("converse/"):
            model = model.split("/", 1)[1]

        if model.startswith("invoke/"):
            model = model.split("/", 1)[1]

        return model

    @staticmethod
    def get_base_model(model: str) -> str:
        """
        Get the base model from the given model name.

        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
        AND "meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1"
        """
        model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
        model = BedrockModelInfo.extract_model_name_from_arn(model)

        potential_region = model.split(".", 1)[0]

        alt_potential_region = model.split("/", 1)[
            0
        ]  # in model cost map we store regional information like `/us-west-2/bedrock-model`

        if (
            potential_region
            in BedrockModelInfo._supported_cross_region_inference_region()
        ):
            return model.split(".", 1)[1]
        elif (
            alt_potential_region in BedrockModelInfo.all_global_regions
            and len(model.split("/", 1)) > 1
        ):
            return model.split("/", 1)[1]

        return model

    @staticmethod
    def _supported_cross_region_inference_region() -> List[str]:
        """
        Abbreviations of regions AWS Bedrock supports for cross region inference
        """
        return ["us", "eu", "apac"]

    @staticmethod
    def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
        """
        Get the bedrock route for the given model.
        """
        base_model = BedrockModelInfo.get_base_model(model)
        alt_model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
        if "invoke/" in model:
            return "invoke"
        elif "converse_like" in model:
            return "converse_like"
        elif "converse/" in model:
            return "converse"
        elif (
            base_model in litellm.bedrock_converse_models
            or alt_model in litellm.bedrock_converse_models
        ):
            return "converse"
        return "invoke"
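To make the routing rules above concrete, a small sketch of `BedrockModelInfo.get_base_model` and `get_bedrock_route` on typical model strings; the "converse" result assumes the Claude model is present in litellm's default `bedrock_converse_models` list:

# Sketch: cross-region prefixes ("us.", "eu.", "apac.") are stripped before lookup.
print(BedrockModelInfo.get_base_model("us.meta.llama3-2-11b-instruct-v1:0"))
# -> "meta.llama3-2-11b-instruct-v1:0"

print(BedrockModelInfo.get_bedrock_route("invoke/meta.llama3-1-8b-instruct-v1:0"))
# -> "invoke" (explicit invoke/ prefix wins)
print(BedrockModelInfo.get_bedrock_route("anthropic.claude-3-5-sonnet-20240620-v1:0"))
# -> "converse" (assuming the model is listed in litellm.bedrock_converse_models)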
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,88 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.

Why separate file? Make it easy to see how transformation works

Covers:
- G1 request format

Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""

import types
from typing import List

from litellm.types.llms.bedrock import (
    AmazonTitanG1EmbeddingRequest,
    AmazonTitanG1EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage


class AmazonTitanG1Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
    """

    def __init__(
        self,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self) -> List[str]:
        return []

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanG1EmbeddingRequest:
        return AmazonTitanG1EmbeddingRequest(inputText=input)

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        total_prompt_tokens = 0

        transformed_responses: List[Embedding] = []
        for index, response in enumerate(response_list):
            _parsed_response = AmazonTitanG1EmbeddingResponse(**response)  # type: ignore
            transformed_responses.append(
                Embedding(
                    embedding=_parsed_response["embedding"],
                    index=index,
                    object="embedding",
                )
            )
            total_prompt_tokens += _parsed_response["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=0,
            total_tokens=total_prompt_tokens,
        )
        return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
@@ -0,0 +1,79 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan multimodal /invoke format.

Why separate file? Make it easy to see how transformation works

Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
"""

from typing import List

from litellm.types.llms.bedrock import (
    AmazonTitanMultimodalEmbeddingConfig,
    AmazonTitanMultimodalEmbeddingRequest,
    AmazonTitanMultimodalEmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
from litellm.utils import get_base64_str, is_base64_encoded


class AmazonTitanMultimodalEmbeddingG1Config:
    """
    Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return ["dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "dimensions":
                optional_params[
                    "embeddingConfig"
                ] = AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=v)
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanMultimodalEmbeddingRequest:
        ## check if b64 encoded str or not ##
        is_encoded = is_base64_encoded(input)
        if is_encoded:  # check if string is b64 encoded image or not
            b64_str = get_base64_str(input)
            transformed_request = AmazonTitanMultimodalEmbeddingRequest(
                inputImage=b64_str
            )
        else:
            transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputText=input)

        for k, v in inference_params.items():
            transformed_request[k] = v  # type: ignore
        return transformed_request

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        total_prompt_tokens = 0
        transformed_responses: List[Embedding] = []
        for index, response in enumerate(response_list):
            _parsed_response = AmazonTitanMultimodalEmbeddingResponse(**response)  # type: ignore
            transformed_responses.append(
                Embedding(
                    embedding=_parsed_response["embedding"],
                    index=index,
                    object="embedding",
                )
            )
            total_prompt_tokens += _parsed_response["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=0,
            total_tokens=total_prompt_tokens,
        )
        return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
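A short sketch of the multimodal request transform above: a base64-encoded input is routed to `inputImage`, anything else to `inputText` (this assumes litellm's `is_base64_encoded` recognizes the data-URI form shown; the payload is a placeholder):

config = AmazonTitanMultimodalEmbeddingG1Config()
# Plain text -> {"inputText": ...}
print(config._transform_request(input="a photo of a cat", inference_params={}))
# Base64 image (placeholder data) -> {"inputImage": ...}
print(config._transform_request(input="data:image/png;base64,iVBORw0KGgo=", inference_params={}))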
@@ -0,0 +1,97 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan V2 /invoke format.

Why separate file? Make it easy to see how transformation works

Covers:
- v2 request format

Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""

import types
from typing import List, Optional

from litellm.types.llms.bedrock import (
    AmazonTitanV2EmbeddingRequest,
    AmazonTitanV2EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage


class AmazonTitanV2Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html

    normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
    dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
    """

    normalize: Optional[bool] = None
    dimensions: Optional[int] = None

    def __init__(
        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self) -> List[str]:
        return ["dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "dimensions":
                optional_params["dimensions"] = v
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanV2EmbeddingRequest:
        return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params)  # type: ignore

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        total_prompt_tokens = 0

        transformed_responses: List[Embedding] = []
        for index, response in enumerate(response_list):
            _parsed_response = AmazonTitanV2EmbeddingResponse(**response)  # type: ignore
            transformed_responses.append(
                Embedding(
                    embedding=_parsed_response["embedding"],
                    index=index,
                    object="embedding",
                )
            )
            total_prompt_tokens += _parsed_response["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=total_prompt_tokens,
            completion_tokens=0,
            total_tokens=total_prompt_tokens,
        )
        return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
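As a reference point, the V2 transform above maps an OpenAI-style `dimensions` argument onto the Titan V2 invoke body; a minimal sketch (payload shape follows the class docstring):

config = AmazonTitanV2Config()
optional_params = config.map_openai_params(
    non_default_params={"dimensions": 256}, optional_params={}
)
request = config._transform_request(input="hello world", inference_params=optional_params)
# request is roughly {"inputText": "hello world", "dimensions": 256}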
@@ -0,0 +1,45 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.

Why separate file? Make it easy to see how transformation works
"""

from typing import List

from litellm.llms.cohere.embed.transformation import CohereEmbeddingConfig
from litellm.types.llms.bedrock import CohereEmbeddingRequest


class BedrockCohereEmbeddingConfig:
    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return ["encoding_format"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "encoding_format":
                optional_params["embedding_types"] = v
        return optional_params

    def _is_v3_model(self, model: str) -> bool:
        return "3" in model

    def _transform_request(
        self, model: str, input: List[str], inference_params: dict
    ) -> CohereEmbeddingRequest:
        transformed_request = CohereEmbeddingConfig()._transform_request(
            model, input, inference_params
        )

        new_transformed_request = CohereEmbeddingRequest(
            input_type=transformed_request["input_type"],
        )
        for k in CohereEmbeddingRequest.__annotations__.keys():
            if k in transformed_request:
                new_transformed_request[k] = transformed_request[k]  # type: ignore

        return new_transformed_request
@@ -0,0 +1,480 @@
"""
Handles embedding calls to Bedrock's `/invoke` endpoint
"""

import copy
import json
from typing import Any, Callable, List, Optional, Tuple, Union

import httpx

import litellm
from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.secret_managers.main import get_secret
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
from litellm.types.utils import EmbeddingResponse

from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError
from .amazon_titan_g1_transformation import AmazonTitanG1Config
from .amazon_titan_multimodal_transformation import (
    AmazonTitanMultimodalEmbeddingG1Config,
)
from .amazon_titan_v2_transformation import AmazonTitanV2Config
from .cohere_transformation import BedrockCohereEmbeddingConfig


class BedrockEmbedding(BaseAWSLLM):
    def _load_credentials(
        self,
        optional_params: dict,
    ) -> Tuple[Any, str]:
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    async def async_embeddings(self):
        pass

    def _make_sync_call(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: str,
        headers: dict,
        data: dict,
    ) -> dict:
        if client is None or not isinstance(client, HTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client
        try:
            response = client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return response.json()

    async def _make_async_call(
        self,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: str,
        headers: dict,
        data: dict,
    ) -> dict:
        if client is None or not isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = get_async_httpx_client(
                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
            )
        else:
            client = client

        try:
            response = await client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return response.json()

    def _single_func_embeddings(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        batch_data: List[dict],
        credentials: Any,
        extra_headers: Optional[dict],
        endpoint_url: str,
        aws_region_name: str,
        model: str,
        logging_obj: Any,
    ):
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        responses: List[dict] = []
        for data in batch_data:
            sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
            headers = {"Content-Type": "application/json"}
            if extra_headers is not None:
                headers = {"Content-Type": "application/json", **extra_headers}
            request = AWSRequest(
                method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
            )
            sigv4.add_auth(request)
            if (
                extra_headers is not None and "Authorization" in extra_headers
            ):  # prevent sigv4 from overwriting the auth header
                request.headers["Authorization"] = extra_headers["Authorization"]
            prepped = request.prepare()

            ## LOGGING
            logging_obj.pre_call(
                input=data,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                },
            )
            response = self._make_sync_call(
                client=client,
                timeout=timeout,
                api_base=prepped.url,
                headers=prepped.headers,  # type: ignore
                data=data,
            )

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response,
                additional_args={"complete_input_dict": data},
            )

            responses.append(response)

        returned_response: Optional[EmbeddingResponse] = None

        ## TRANSFORM RESPONSE ##
        if model == "amazon.titan-embed-image-v1":
            returned_response = (
                AmazonTitanMultimodalEmbeddingG1Config()._transform_response(
                    response_list=responses, model=model
                )
            )
        elif model == "amazon.titan-embed-text-v1":
            returned_response = AmazonTitanG1Config()._transform_response(
                response_list=responses, model=model
            )
        elif model == "amazon.titan-embed-text-v2:0":
            returned_response = AmazonTitanV2Config()._transform_response(
                response_list=responses, model=model
            )

        if returned_response is None:
            raise Exception(
                "Unable to map model response to known provider format. model={}".format(
                    model
                )
            )

        return returned_response

    async def _async_single_func_embeddings(
        self,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        batch_data: List[dict],
        credentials: Any,
        extra_headers: Optional[dict],
        endpoint_url: str,
        aws_region_name: str,
        model: str,
        logging_obj: Any,
    ):
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        responses: List[dict] = []
        for data in batch_data:
            sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
            headers = {"Content-Type": "application/json"}
            if extra_headers is not None:
                headers = {"Content-Type": "application/json", **extra_headers}
            request = AWSRequest(
                method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
            )
            sigv4.add_auth(request)
            if (
                extra_headers is not None and "Authorization" in extra_headers
            ):  # prevent sigv4 from overwriting the auth header
                request.headers["Authorization"] = extra_headers["Authorization"]
            prepped = request.prepare()

            ## LOGGING
            logging_obj.pre_call(
                input=data,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                },
            )
            response = await self._make_async_call(
                client=client,
                timeout=timeout,
                api_base=prepped.url,
                headers=prepped.headers,  # type: ignore
                data=data,
            )

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response,
                additional_args={"complete_input_dict": data},
            )

            responses.append(response)

        returned_response: Optional[EmbeddingResponse] = None

        ## TRANSFORM RESPONSE ##
        if model == "amazon.titan-embed-image-v1":
            returned_response = (
                AmazonTitanMultimodalEmbeddingG1Config()._transform_response(
                    response_list=responses, model=model
                )
            )
        elif model == "amazon.titan-embed-text-v1":
            returned_response = AmazonTitanG1Config()._transform_response(
                response_list=responses, model=model
            )
        elif model == "amazon.titan-embed-text-v2:0":
            returned_response = AmazonTitanV2Config()._transform_response(
                response_list=responses, model=model
            )

        if returned_response is None:
            raise Exception(
                "Unable to map model response to known provider format. model={}".format(
                    model
                )
            )

        return returned_response

    def embeddings(
        self,
        model: str,
        input: List[str],
        api_base: Optional[str],
        model_response: EmbeddingResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
        timeout: Optional[Union[float, httpx.Timeout]],
        aembedding: Optional[bool],
        extra_headers: Optional[dict],
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        credentials, aws_region_name = self._load_credentials(optional_params)

        ### TRANSFORMATION ###
        provider = model.split(".")[0]
        inference_params = copy.deepcopy(optional_params)
        inference_params = {
            k: v
            for k, v in inference_params.items()
            if k.lower() not in self.aws_authentication_params
        }
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call
        modelId = (
            optional_params.pop("model_id", None) or model
        )  # default to model if not passed

        data: Optional[CohereEmbeddingRequest] = None
        batch_data: Optional[List] = None
        if provider == "cohere":
            data = BedrockCohereEmbeddingConfig()._transform_request(
                model=model, input=input, inference_params=inference_params
            )
        elif provider == "amazon" and model in [
            "amazon.titan-embed-image-v1",
            "amazon.titan-embed-text-v1",
            "amazon.titan-embed-text-v2:0",
        ]:
            batch_data = []
            for i in input:
                if model == "amazon.titan-embed-image-v1":
                    transformed_request: (
                        AmazonEmbeddingRequest
                    ) = AmazonTitanMultimodalEmbeddingG1Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                elif model == "amazon.titan-embed-text-v1":
                    transformed_request = AmazonTitanG1Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                elif model == "amazon.titan-embed-text-v2:0":
                    transformed_request = AmazonTitanV2Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                else:
                    raise Exception(
                        "Unmapped model. Received={}. Expected={}".format(
                            model,
                            [
                                "amazon.titan-embed-image-v1",
                                "amazon.titan-embed-text-v1",
                                "amazon.titan-embed-text-v2:0",
                            ],
                        )
                    )
                batch_data.append(transformed_request)

        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=optional_params.pop(
                "aws_bedrock_runtime_endpoint", None
            ),
            aws_region_name=aws_region_name,
        )
        endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"

        if batch_data is not None:
            if aembedding:
                return self._async_single_func_embeddings(  # type: ignore
                    client=(
                        client
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
                    timeout=timeout,
                    batch_data=batch_data,
                    credentials=credentials,
                    extra_headers=extra_headers,
                    endpoint_url=endpoint_url,
                    aws_region_name=aws_region_name,
                    model=model,
                    logging_obj=logging_obj,
                )
            return self._single_func_embeddings(
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                timeout=timeout,
                batch_data=batch_data,
                credentials=credentials,
                extra_headers=extra_headers,
                endpoint_url=endpoint_url,
                aws_region_name=aws_region_name,
                model=model,
                logging_obj=logging_obj,
            )
        elif data is None:
            raise Exception("Unable to map Bedrock request to provider")

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        request = AWSRequest(
            method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        ## ROUTING ##
        return cohere_embedding(
            model=model,
            input=input,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            encoding=encoding,
            data=data,  # type: ignore
            complete_api_base=prepped.url,
            api_key=None,
            aembedding=aembedding,
            timeout=timeout,
            client=client,
            headers=prepped.headers,  # type: ignore
        )
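End to end, the handler above is what backs Bedrock embedding calls made through litellm's public API; a hedged usage sketch (model name and region are examples, credentials are read from the environment):

import litellm

response = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v2:0",
    input=["hello world", "goodbye world"],
    aws_region_name="us-west-2",   # optional; falls back to AWS_REGION_NAME / AWS_REGION / us-west-2
    dimensions=256,                 # mapped onto the Titan V2 request by AmazonTitanV2Config
)
print(len(response.data), response.usage.prompt_tokens)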
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,165 @@
import types
from typing import Any, Dict, List, Optional

from openai.types.image import Image

from litellm.types.llms.bedrock import (
    AmazonNovaCanvasColorGuidedGenerationParams,
    AmazonNovaCanvasColorGuidedRequest,
    AmazonNovaCanvasImageGenerationConfig,
    AmazonNovaCanvasRequestBase,
    AmazonNovaCanvasTextToImageParams,
    AmazonNovaCanvasTextToImageRequest,
    AmazonNovaCanvasTextToImageResponse,
)
from litellm.types.utils import ImageResponse


class AmazonNovaCanvasConfig:
    """
    Reference: https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/model-catalog/serverless/amazon.nova-canvas-v1:0
    """

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """ """
        return ["n", "size", "quality"]

    @classmethod
    def _is_nova_model(cls, model: Optional[str] = None) -> bool:
        """
        Returns True if the model is a Nova Canvas model

        Nova models follow this pattern:

        """
        if model:
            if "amazon.nova-canvas" in model:
                return True
        return False

    @classmethod
    def transform_request_body(
        cls, text: str, optional_params: dict
    ) -> AmazonNovaCanvasRequestBase:
        """
        Transform the request body for Amazon Nova Canvas model
        """
        task_type = optional_params.pop("taskType", "TEXT_IMAGE")
        image_generation_config = optional_params.pop("imageGenerationConfig", {})
        image_generation_config = {**image_generation_config, **optional_params}
        if task_type == "TEXT_IMAGE":
            text_to_image_params: Dict[str, Any] = image_generation_config.pop(
                "textToImageParams", {}
            )
            text_to_image_params = {"text": text, **text_to_image_params}
            try:
                text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
                    **text_to_image_params  # type: ignore
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
                )

            try:
                image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
                    **image_generation_config
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
                )

            return AmazonNovaCanvasTextToImageRequest(
                textToImageParams=text_to_image_params_typed,
                taskType=task_type,
                imageGenerationConfig=image_generation_config_typed,
            )
        if task_type == "COLOR_GUIDED_GENERATION":
            color_guided_generation_params: Dict[
                str, Any
            ] = image_generation_config.pop("colorGuidedGenerationParams", {})
            color_guided_generation_params = {
                "text": text,
                **color_guided_generation_params,
            }
            try:
                color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
                    **color_guided_generation_params  # type: ignore
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
                )

            try:
                image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
                    **image_generation_config
                )
            except Exception as e:
                raise ValueError(
                    f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
                )

            return AmazonNovaCanvasColorGuidedRequest(
                taskType=task_type,
                colorGuidedGenerationParams=color_guided_generation_params_typed,
                imageGenerationConfig=image_generation_config_typed,
            )
        raise NotImplementedError(f"Task type {task_type} is not supported")

    @classmethod
    def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
        """
        Map the OpenAI params to the Bedrock params
        """
        _size = non_default_params.get("size")
        if _size is not None:
            width, height = _size.split("x")
            optional_params["width"], optional_params["height"] = int(width), int(
                height
            )
        if non_default_params.get("n") is not None:
            optional_params["numberOfImages"] = non_default_params.get("n")
        if non_default_params.get("quality") is not None:
            if non_default_params.get("quality") in ("hd", "premium"):
                optional_params["quality"] = "premium"
            if non_default_params.get("quality") == "standard":
                optional_params["quality"] = "standard"
        return optional_params

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """
        Transform the response dict to the OpenAI response
        """
        nova_response = AmazonNovaCanvasTextToImageResponse(**response_dict)
        openai_images: List[Image] = []
        for _img in nova_response.get("images", []):
            openai_images.append(Image(b64_json=_img))

        model_response.data = openai_images
        return model_response
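To illustrate the OpenAI -> Nova Canvas parameter mapping above, a small sketch (values are examples; the exact keys accepted by the imageGenerationConfig TypedDict follow litellm's bedrock types):

mapped = AmazonNovaCanvasConfig.map_openai_params(
    non_default_params={"size": "1024x1024", "n": 2, "quality": "hd"},
    optional_params={},
)
# mapped == {"width": 1024, "height": 1024, "numberOfImages": 2, "quality": "premium"}

body = AmazonNovaCanvasConfig.transform_request_body(
    text="a watercolor lighthouse at dusk", optional_params=dict(mapped)
)
# body is roughly {"textToImageParams": {"text": ...}, "taskType": "TEXT_IMAGE", "imageGenerationConfig": {...}}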
@@ -0,0 +1,104 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. Multiple of 64, >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
return ["size"]
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(
|
||||
cls,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
_size = non_default_params.get("size")
|
||||
if _size is not None:
|
||||
width, height = _size.split("x")
|
||||
optional_params["width"] = int(width)
|
||||
optional_params["height"] = int(height)
|
||||
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
image_list: List[Image] = []
|
||||
for artifact in response_dict["artifacts"]:
|
||||
_image = Image(b64_json=artifact["base64"])
|
||||
image_list.append(_image)
|
||||
|
||||
model_response.data = image_list
|
||||
|
||||
return model_response
|
||||
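# A minimal sketch of how the Stability (SDXL) config above is consumed,
# assuming it is exported as litellm.AmazonStabilityConfig. The OpenAI "size"
# string is split into the width/height fields Bedrock expects; 512x896 is one
# of the dimensions allowed by the SDXL Beta validation described above.
if __name__ == "__main__":
    import litellm

    params = litellm.AmazonStabilityConfig.map_openai_params(
        non_default_params={"size": "512x896"},
        optional_params={"cfg_scale": 7},
    )
    # expected: {"cfg_scale": 7, "width": 512, "height": 896}
    print(params)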
@@ -0,0 +1,100 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonStability3TextToImageRequest,
|
||||
AmazonStability3TextToImageResponse,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonStability3Config:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
"""
|
||||
No additional OpenAI params are mapped for stability 3
|
||||
"""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Returns True if the model is a Stability 3 model
|
||||
|
||||
Stability 3 models follow this pattern:
|
||||
sd3-large
|
||||
sd3-large-turbo
|
||||
sd3-medium
|
||||
sd3.5-large
|
||||
sd3.5-large-turbo
|
||||
|
||||
Stability ultra models
|
||||
stable-image-ultra-v1
|
||||
"""
|
||||
if model:
|
||||
if "sd3" in model or "sd3.5" in model:
|
||||
return True
|
||||
if "stable-image-ultra-v1" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def transform_request_body(
|
||||
cls, prompt: str, optional_params: dict
|
||||
) -> AmazonStability3TextToImageRequest:
|
||||
"""
|
||||
Transform the request body for the Stability 3 models
|
||||
"""
|
||||
data = AmazonStability3TextToImageRequest(prompt=prompt, **optional_params)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
|
||||
"""
|
||||
Map the OpenAI params to the Bedrock params
|
||||
|
||||
No OpenAI params are mapped for Stability 3, so directly return the optional_params
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the response dict to the OpenAI response
|
||||
"""
|
||||
|
||||
stability_3_response = AmazonStability3TextToImageResponse(**response_dict)
|
||||
openai_images: List[Image] = []
|
||||
for _img in stability_3_response.get("images", []):
|
||||
openai_images.append(Image(b64_json=_img))
|
||||
|
||||
model_response.data = openai_images
|
||||
return model_response
|
||||
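# A minimal sketch exercising the Stability 3 helpers above, assuming the class
# is exported as litellm.AmazonStability3Config. The request body for SD3-style
# models is just the prompt plus pass-through optional params; "aspect_ratio"
# is an illustrative Stability API parameter.
if __name__ == "__main__":
    import litellm

    assert litellm.AmazonStability3Config._is_stability_3_model("sd3.5-large") is True
    assert litellm.AmazonStability3Config._is_stability_3_model("stable-diffusion-xl-v1") is False

    body = litellm.AmazonStability3Config.transform_request_body(
        prompt="a lighthouse at dusk", optional_params={"aspect_ratio": "16:9"}
    )
    print(dict(body))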
@@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
|
||||
model: str,
|
||||
image_response: ImageResponse,
|
||||
size: Optional[str] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Bedrock image generation cost calculator
|
||||
|
||||
Handles both Stability 1 and Stability 3 models
|
||||
"""
|
||||
if litellm.AmazonStability3Config._is_stability_3_model(model=model):
|
||||
pass
|
||||
else:
|
||||
# Stability 1 models
|
||||
optional_params = optional_params or {}
|
||||
|
||||
# see model_prices_and_context_window.json for details on how steps is used
|
||||
# Reference pricing by steps for stability 1: https://aws.amazon.com/bedrock/pricing/
|
||||
_steps = optional_params.get("steps", 50)
|
||||
steps = "max-steps" if _steps > 50 else "50-steps"
|
||||
|
||||
# size is stored in model_prices_and_context_window.json as 1024-x-1024
|
||||
# the incoming size is in the format 1024x1024
|
||||
size = size or "1024-x-1024"
|
||||
model = f"{size}/{steps}/{model}"
|
||||
|
||||
_model_info = litellm.get_model_info(
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
)
|
||||
|
||||
output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
|
||||
num_images: int = len(image_response.data)
|
||||
return output_cost_per_image * num_images
|
||||
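# A minimal sketch of how the Stability 1 pricing key above is assembled before
# the litellm.get_model_info lookup. The size/steps prefixes must match the
# "1024-x-1024/50-steps/..." style keys in model_prices_and_context_window.json;
# the model id is illustrative.
if __name__ == "__main__":
    model = "stability.stable-diffusion-xl-v1"
    optional_params = {"steps": 30}

    steps = "max-steps" if optional_params.get("steps", 50) > 50 else "50-steps"
    size = "1024-x-1024"
    lookup_key = f"{size}/{steps}/{model}"
    # expected: "1024-x-1024/50-steps/stability.stable-diffusion-xl-v1"
    print(lookup_key)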
@@ -0,0 +1,321 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockImagePreparedRequest(BaseModel):
|
||||
"""
|
||||
Internal/Helper class for preparing the request for bedrock image generation
|
||||
"""
|
||||
|
||||
endpoint_url: str
|
||||
prepped: AWSPreparedRequest
|
||||
body: bytes
|
||||
data: dict
|
||||
|
||||
|
||||
class BedrockImageGeneration(BaseAWSLLM):
|
||||
"""
|
||||
Bedrock Image Generation handler
|
||||
"""
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
logging_obj: LitellmLogging,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
aimg_generation: bool = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
if aimg_generation is True:
|
||||
return self.async_image_generation(
|
||||
prepared_request=prepared_request,
|
||||
timeout=timeout,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
model_response=model_response,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
model_response = self._transform_response_dict_to_openai_response(
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
data=prepared_request.data,
|
||||
)
|
||||
return model_response
|
||||
|
||||
async def async_image_generation(
|
||||
self,
|
||||
prepared_request: BedrockImagePreparedRequest,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
model: str,
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Asynchronous handler for bedrock image generation
|
||||
|
||||
Awaits the response from the bedrock image generation endpoint
|
||||
"""
|
||||
async_client = client or get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BEDROCK,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
### FORMAT RESPONSE TO OPENAI FORMAT ###
|
||||
model_response = self._transform_response_dict_to_openai_response(
|
||||
model=model,
|
||||
logging_obj=logging_obj,
|
||||
prompt=prompt,
|
||||
response=response,
|
||||
data=prepared_request.data,
|
||||
model_response=model_response,
|
||||
)
|
||||
return model_response
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
) -> BedrockImagePreparedRequest:
|
||||
"""
|
||||
Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
|
||||
|
||||
Args:
|
||||
model (str): The model to use for the image generation
|
||||
optional_params (dict): The optional parameters for the image generation
|
||||
api_base (Optional[str]): The base URL for the Bedrock API
|
||||
extra_headers (Optional[dict]): The extra headers to include in the request
|
||||
logging_obj (LitellmLogging): The logging object to use for logging
|
||||
prompt (str): The prompt to use for the image generation
|
||||
Returns:
|
||||
BedrockImagePreparedRequest: The prepared request object
|
||||
|
||||
The BedrockImagePreparedRequest contains:
|
||||
endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
|
||||
prepped (httpx.Request): The prepared request object
|
||||
body (bytes): The request body
|
||||
"""
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params, model
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
modelId = model
|
||||
_, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
|
||||
sigv4 = SigV4Auth(
|
||||
boto3_credentials_info.credentials,
|
||||
"bedrock",
|
||||
boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
|
||||
data = self._get_request_body(
|
||||
model=model, prompt=prompt, optional_params=optional_params
|
||||
)
|
||||
|
||||
# Make POST Request
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=proxy_endpoint_url, data=body, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": proxy_endpoint_url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
return BedrockImagePreparedRequest(
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
prepped=prepped,
|
||||
body=body,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def _get_request_body(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the request body for the Bedrock Image Generation API
|
||||
|
||||
Checks the model/provider and transforms the request body accordingly
|
||||
|
||||
Returns:
|
||||
dict: The request body to use for the Bedrock Image Generation API
|
||||
"""
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop(
|
||||
"user", None
|
||||
) # make sure user is not passed in for bedrock call
|
||||
data = {}
|
||||
if provider == "stability":
|
||||
if litellm.AmazonStability3Config._is_stability_3_model(model):
|
||||
request_body = litellm.AmazonStability3Config.transform_request_body(
|
||||
prompt=prompt, optional_params=optional_params
|
||||
)
|
||||
return dict(request_body)
|
||||
else:
|
||||
prompt = prompt.replace(os.linesep, " ")
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonStabilityConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = {
|
||||
"text_prompts": [{"text": prompt, "weight": 1}],
|
||||
**inference_params,
|
||||
}
|
||||
elif provider == "amazon":
|
||||
return dict(
|
||||
litellm.AmazonNovaCanvasConfig.transform_request_body(
|
||||
text=prompt, optional_params=optional_params
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise BedrockError(
|
||||
status_code=422, message=f"Unsupported model={model} passed in"
|
||||
)
|
||||
return data
|
||||
|
||||
def _transform_response_dict_to_openai_response(
|
||||
self,
|
||||
model_response: ImageResponse,
|
||||
model: str,
|
||||
logging_obj: LitellmLogging,
|
||||
prompt: str,
|
||||
response: httpx.Response,
|
||||
data: dict,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transforms the Image Generation response from Bedrock to OpenAI format
|
||||
"""
|
||||
|
||||
## LOGGING
|
||||
if logging_obj is not None:
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
verbose_logger.debug("raw model_response: %s", response.text)
|
||||
response_dict = response.json()
|
||||
if response_dict is None:
|
||||
raise ValueError("Error in response object format, got None")
|
||||
|
||||
config_class = (
|
||||
litellm.AmazonStability3Config
|
||||
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
|
||||
else (
|
||||
litellm.AmazonNovaCanvasConfig
|
||||
if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
|
||||
else litellm.AmazonStabilityConfig
|
||||
)
|
||||
)
|
||||
config_class.transform_response_dict_to_openai_response(
|
||||
model_response=model_response,
|
||||
response_dict=response_dict,
|
||||
)
|
||||
|
||||
return model_response
|
||||
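# A minimal, hedged sketch of driving the handler above through litellm's
# public image_generation entrypoint. Assumes AWS credentials and region are
# already available in the environment; the model id is illustrative.
if __name__ == "__main__":
    import litellm

    response = litellm.image_generation(
        model="bedrock/stability.stable-diffusion-xl-v1",
        prompt="a watercolor painting of a fox",
        size="1024x1024",
    )
    # response.data is a list of openai.types.image.Image with b64_json payloads
    print(len(response.data))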
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.bedrock import BedrockPreparedRequest
|
||||
from litellm.types.rerank import RerankRequest
|
||||
from litellm.types.utils import RerankResponse
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
from .transformation import BedrockRerankConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockRerankHandler(BaseAWSLLM):
|
||||
async def arerank(
|
||||
self,
|
||||
prepared_request: BedrockPreparedRequest,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
):
|
||||
if client is None:
|
||||
client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
|
||||
try:
|
||||
response = await client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response.json())
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
optional_params: dict,
|
||||
logging_obj: LitellmLogging,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
_is_async: Optional[bool] = False,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
) -> RerankResponse:
|
||||
request_data = RerankRequest(
|
||||
model=model,
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
)
|
||||
data = BedrockRerankConfig()._transform_request(request_data)
|
||||
|
||||
prepared_request = self._prepare_request(
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
data=cast(dict, data),
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=data,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepared_request["endpoint_url"],
|
||||
"headers": prepared_request["prepped"].headers,
|
||||
},
|
||||
)
|
||||
|
||||
if _is_async:
|
||||
return self.arerank(prepared_request, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None) # type: ignore
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client()
|
||||
try:
|
||||
response = client.post(url=prepared_request["endpoint_url"], headers=prepared_request["prepped"].headers, data=prepared_request["body"]) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
logging_obj.post_call(
|
||||
original_response=response.text,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
return BedrockRerankConfig()._transform_response(response_json)
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
extra_headers: Optional[dict],
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
) -> BedrockPreparedRequest:
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
boto3_credentials_info = self._get_boto_credentials_from_optional_params(
|
||||
optional_params, model
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
_, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
proxy_endpoint_url = proxy_endpoint_url.replace(
|
||||
"bedrock-runtime", "bedrock-agent-runtime"
|
||||
)
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"
|
||||
sigv4 = SigV4Auth(
|
||||
boto3_credentials_info.credentials,
|
||||
"bedrock",
|
||||
boto3_credentials_info.aws_region_name,
|
||||
)
|
||||
# Make POST Request
|
||||
body = json.dumps(data).encode("utf-8")
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=proxy_endpoint_url, data=body, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
prepped = request.prepare()
|
||||
|
||||
return BedrockPreparedRequest(
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
prepped=prepped,
|
||||
body=body,
|
||||
data=data,
|
||||
)
|
||||
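# A minimal, hedged sketch of calling the rerank handler above via litellm's
# public rerank entrypoint. The model ARN is illustrative; the transformation
# below passes the model string through as Bedrock's modelArn. Assumes AWS
# credentials and region are available in the environment.
if __name__ == "__main__":
    import litellm

    response = litellm.rerank(
        model="bedrock/arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
        query="hello",
        documents=["hello", "world"],
        top_n=2,
    )
    print(response.results)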
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.
|
||||
|
||||
Why a separate file? To make it easy to see how the transformation works.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
BedrockRerankBedrockRerankingConfiguration,
|
||||
BedrockRerankConfiguration,
|
||||
BedrockRerankInlineDocumentSource,
|
||||
BedrockRerankModelConfiguration,
|
||||
BedrockRerankQuery,
|
||||
BedrockRerankRequest,
|
||||
BedrockRerankSource,
|
||||
BedrockRerankTextDocument,
|
||||
BedrockRerankTextQuery,
|
||||
)
|
||||
from litellm.types.rerank import (
|
||||
RerankBilledUnits,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResponseMeta,
|
||||
RerankResponseResult,
|
||||
RerankTokens,
|
||||
)
|
||||
|
||||
|
||||
class BedrockRerankConfig:
|
||||
def _transform_sources(
|
||||
self, documents: List[Union[str, dict]]
|
||||
) -> List[BedrockRerankSource]:
|
||||
"""
|
||||
Transform the sources from RerankRequest format to Bedrock format.
|
||||
"""
|
||||
_sources = []
|
||||
for document in documents:
|
||||
if isinstance(document, str):
|
||||
_sources.append(
|
||||
BedrockRerankSource(
|
||||
inlineDocumentSource=BedrockRerankInlineDocumentSource(
|
||||
textDocument=BedrockRerankTextDocument(text=document),
|
||||
type="TEXT",
|
||||
),
|
||||
type="INLINE",
|
||||
)
|
||||
)
|
||||
else:
|
||||
_sources.append(
|
||||
BedrockRerankSource(
|
||||
inlineDocumentSource=BedrockRerankInlineDocumentSource(
|
||||
jsonDocument=document, type="JSON"
|
||||
),
|
||||
type="INLINE",
|
||||
)
|
||||
)
|
||||
return _sources
|
||||
|
||||
def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
|
||||
"""
|
||||
Transform the request from RerankRequest format to Bedrock format.
|
||||
"""
|
||||
_sources = self._transform_sources(request_data.documents)
|
||||
|
||||
return BedrockRerankRequest(
|
||||
queries=[
|
||||
BedrockRerankQuery(
|
||||
textQuery=BedrockRerankTextQuery(text=request_data.query),
|
||||
type="TEXT",
|
||||
)
|
||||
],
|
||||
rerankingConfiguration=BedrockRerankConfiguration(
|
||||
bedrockRerankingConfiguration=BedrockRerankBedrockRerankingConfiguration(
|
||||
modelConfiguration=BedrockRerankModelConfiguration(
|
||||
modelArn=request_data.model
|
||||
),
|
||||
numberOfResults=request_data.top_n or len(request_data.documents),
|
||||
),
|
||||
type="BEDROCK_RERANKING_MODEL",
|
||||
),
|
||||
sources=_sources,
|
||||
)
|
||||
|
||||
def _transform_response(self, response: dict) -> RerankResponse:
|
||||
"""
|
||||
Transform the response from Bedrock into the RerankResponse format.
|
||||
|
||||
example input:
|
||||
{"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}
|
||||
"""
|
||||
_billed_units = RerankBilledUnits(
|
||||
**response.get("usage", {"search_units": 1})
|
||||
) # by default 1 search unit
|
||||
_tokens = RerankTokens(**response.get("usage", {}))
|
||||
rerank_meta = RerankResponseMeta(billed_units=_billed_units, tokens=_tokens)
|
||||
|
||||
_results: Optional[List[RerankResponseResult]] = None
|
||||
|
||||
bedrock_results = response.get("results")
|
||||
if bedrock_results:
|
||||
_results = [
|
||||
RerankResponseResult(
|
||||
index=result.get("index"),
|
||||
relevance_score=result.get("relevanceScore"),
|
||||
)
|
||||
for result in bedrock_results
|
||||
]
|
||||
|
||||
if _results is None:
|
||||
raise ValueError(f"No results found in the response={response}")
|
||||
|
||||
return RerankResponse(
|
||||
id=response.get("id") or str(uuid.uuid4()),
|
||||
results=_results,
|
||||
meta=rerank_meta,
|
||||
) # Return response
|
||||
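# A minimal sketch of the Cohere-to-Bedrock translation above. The response
# payload mirrors the example in _transform_response's docstring; the model ARN
# and documents are illustrative.
if __name__ == "__main__":
    from litellm.types.rerank import RerankRequest

    config = BedrockRerankConfig()
    bedrock_request = config._transform_request(
        RerankRequest(
            model="arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0",
            query="What is the capital of France?",
            documents=["Paris is the capital of France.", "Berlin is in Germany."],
            top_n=1,
        )
    )
    print(bedrock_request["rerankingConfiguration"]["bedrockRerankingConfiguration"])

    openai_style = config._transform_response(
        {"results": [{"index": 0, "relevanceScore": 0.68}, {"index": 1, "relevanceScore": 0.59}]}
    )
    print(openai_style.results)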