structure saas with tools
@@ -0,0 +1,2 @@
from .converse_handler import BedrockConverseLLM
from .invoke_handler import BedrockLLM
@@ -0,0 +1,466 @@
|
||||
import json
|
||||
import urllib
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM, Credentials
|
||||
from ..common_utils import BedrockError
|
||||
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
||||
|
||||
|
||||
def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj: LiteLLMLoggingObject,
|
||||
json_mode: Optional[bool] = False,
|
||||
fake_stream: bool = False,
|
||||
):
|
||||
if client is None:
|
||||
client = _get_httpx_client() # Create a new client if none provided
|
||||
|
||||
response = client.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
stream=not fake_stream,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(
|
||||
status_code=response.status_code, message=str(response.read())
|
||||
)
|
||||
|
||||
if fake_stream:
|
||||
model_response: (
|
||||
ModelResponse
|
||||
) = litellm.AmazonConverseConfig()._transform_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=litellm.ModelResponse(),
|
||||
stream=True,
|
||||
logging_obj=logging_obj,
|
||||
optional_params={},
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
encoding=litellm.encoding,
|
||||
) # type: ignore
|
||||
completion_stream: Any = MockResponseIterator(
|
||||
model_response=model_response, json_mode=json_mode
|
||||
)
|
||||
else:
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response="first stream response received",
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
|
||||
|
||||
|
||||
class BedrockConverseLLM(BaseAWSLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def encode_model_id(self, model_id: str) -> str:
|
||||
"""
|
||||
URL-encode the model ID (e.g. ARNs containing '/' or ':') so it matches the format expected in the Bedrock request path.
|
||||
Args:
|
||||
model_id (str): The model ID to encode.
|
||||
Returns:
|
||||
str: The URL-encoded model ID.
|
||||
"""
|
||||
return urllib.parse.quote(model_id, safe="") # type: ignore
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
credentials: Credentials,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
fake_stream: bool = False,
|
||||
json_mode: Optional[bool] = False,
|
||||
) -> CustomStreamWrapper:
|
||||
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(request_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||
extra_headers=headers,
|
||||
endpoint_url=api_base,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": dict(prepped.headers),
|
||||
},
|
||||
)
|
||||
|
||||
completion_stream = await make_call(
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=dict(prepped.headers),
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
fake_stream=fake_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj: LiteLLMLoggingObject,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
credentials: Credentials,
|
||||
logger_fn=None,
|
||||
headers: dict = {},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
request_data = await litellm.AmazonConverseConfig()._async_transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(request_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
|
||||
extra_headers=headers,
|
||||
endpoint_url=api_base,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": api_base,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
|
||||
headers = dict(prepped.headers)
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = get_async_httpx_client(
|
||||
params=_params, llm_provider=litellm.LlmProviders.BEDROCK
|
||||
)
|
||||
else:
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
logging_obj=logging_obj,
|
||||
) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return litellm.AmazonConverseConfig()._transform_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream if isinstance(stream, bool) else False,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion( # noqa: PLR0915
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: Optional[str],
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
encoding,
|
||||
logging_obj: LiteLLMLoggingObject,
|
||||
optional_params: dict,
|
||||
acompletion: bool,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params: dict,
|
||||
logger_fn=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
):
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
unencoded_model_id = optional_params.pop("model_id", None)
|
||||
fake_stream = optional_params.pop("fake_stream", False)
|
||||
json_mode = optional_params.get("json_mode", False)
|
||||
if unencoded_model_id is not None:
|
||||
modelId = self.encode_model_id(model_id=unencoded_model_id)
|
||||
else:
|
||||
modelId = self.encode_model_id(model_id=model)
|
||||
|
||||
if stream is True and "ai21" in modelId:
|
||||
fake_stream = True
|
||||
|
||||
### SET REGION NAME ###
|
||||
aws_region_name = self._get_aws_region_name(
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
model_id=unencoded_model_id,
|
||||
)
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name, etc. from optional_params, since the Bedrock completion call fails if they are passed through
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = optional_params.pop("aws_session_token", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
|
||||
optional_params.pop("aws_region_name", None)
|
||||
|
||||
litellm_params[
|
||||
"aws_region_name"
|
||||
] = aws_region_name # [DO NOT DELETE] important for async calls
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
if (stream is not None and stream is True) and not fake_stream:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
|
||||
|
||||
## COMPLETION CALL
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream is True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
json_mode=json_mode,
|
||||
fake_stream=fake_stream,
|
||||
credentials=credentials,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=proxy_endpoint_url,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream, # type: ignore
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
credentials=credentials,
|
||||
) # type: ignore
|
||||
|
||||
## TRANSFORMATION ##
|
||||
|
||||
_data = litellm.AmazonConverseConfig()._transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
data = json.dumps(_data)
|
||||
|
||||
prepped = self.get_request_headers(
|
||||
credentials=credentials,
|
||||
aws_region_name=aws_region_name,
|
||||
extra_headers=extra_headers,
|
||||
endpoint_url=proxy_endpoint_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": proxy_endpoint_url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = _get_httpx_client(_params) # type: ignore
|
||||
else:
|
||||
client = client
|
||||
|
||||
if stream is not None and stream is True:
|
||||
completion_stream = make_sync_call(
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, HTTPHandler)
|
||||
else None
|
||||
),
|
||||
api_base=proxy_endpoint_url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
json_mode=json_mode,
|
||||
fake_stream=fake_stream,
|
||||
)
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
return streaming_response
|
||||
|
||||
### COMPLETION
|
||||
|
||||
try:
|
||||
response = client.post(
|
||||
url=proxy_endpoint_url,
|
||||
headers=prepped.headers,
|
||||
data=data,
|
||||
logging_obj=logging_obj,
|
||||
) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return litellm.AmazonConverseConfig()._transform_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream if isinstance(stream, bool) else False,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
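A minimal usage sketch of the handler above, assuming litellm's public `completion()` entrypoint routes `bedrock/...` models to `BedrockConverseLLM.completion`; the model ID, region, and prompt are illustrative, not part of this commit.

import litellm

# Illustrative model ID and region; stream=True exercises the make_sync_call /
# CustomStreamWrapper branch implemented above.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    aws_region_name="us-west-2",
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")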
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Uses base_llm_http_handler to call the 'converse-like' endpoint.
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/8085
|
||||
"""
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
|
||||
"""
|
||||
@@ -0,0 +1,852 @@
|
||||
"""
|
||||
Translating between OpenAI's `/chat/completions` format and Amazon's `/converse` format
|
||||
"""
|
||||
|
||||
import copy
|
||||
import time
|
||||
import types
|
||||
from typing import List, Literal, Optional, Tuple, Union, cast, overload
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
BedrockConverseMessagesProcessor,
|
||||
_bedrock_converse_messages_pt,
|
||||
_bedrock_tools_pt,
|
||||
)
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.bedrock import *
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionRedactedThinkingBlock,
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionSystemMessage,
|
||||
ChatCompletionThinkingBlock,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
ChatCompletionUserMessage,
|
||||
OpenAIChatCompletionToolParam,
|
||||
OpenAIMessageContentListBlock,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse, PromptTokensDetailsWrapper, Usage
|
||||
from litellm.utils import add_dummy_tool, has_tool_call_blocks
|
||||
|
||||
from ..common_utils import BedrockError, BedrockModelInfo, get_bedrock_tool_name
|
||||
|
||||
|
||||
class AmazonConverseConfig(BaseConfig):
|
||||
"""
|
||||
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
||||
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
|
||||
"""
|
||||
|
||||
maxTokens: Optional[int]
|
||||
stopSequences: Optional[List[str]]
|
||||
temperature: Optional[int]
|
||||
topP: Optional[int]
|
||||
topK: Optional[int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
stopSequences: Optional[List[str]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
topP: Optional[int] = None,
|
||||
topK: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return "bedrock_converse"
|
||||
|
||||
@classmethod
|
||||
def get_config_blocks(cls) -> dict:
|
||||
return {
|
||||
"guardrailConfig": GuardrailConfigBlock,
|
||||
"performanceConfig": PerformanceConfigBlock,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
supported_params = [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
## Filter out 'cross-region' from model name
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
|
||||
if (
|
||||
base_model.startswith("anthropic")
|
||||
or base_model.startswith("mistral")
|
||||
or base_model.startswith("cohere")
|
||||
or base_model.startswith("meta.llama3-1")
|
||||
or base_model.startswith("meta.llama3-2")
|
||||
or base_model.startswith("meta.llama3-3")
|
||||
or base_model.startswith("amazon.nova")
|
||||
):
|
||||
supported_params.append("tools")
|
||||
|
||||
if litellm.utils.supports_tool_choice(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
# only anthropic and mistral support the tool_choice config; other providers (e.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
supported_params.append("tool_choice")
|
||||
|
||||
if (
|
||||
"claude-3-7" in model
|
||||
): # [TODO]: move to a 'supports_reasoning_content' param from model cost map
|
||||
supported_params.append("thinking")
|
||||
supported_params.append("reasoning_effort")
|
||||
return supported_params
|
||||
|
||||
def map_tool_choice_values(
|
||||
self, model: str, tool_choice: Union[str, dict], drop_params: bool
|
||||
) -> Optional[ToolChoiceValuesBlock]:
|
||||
if tool_choice == "none":
|
||||
if litellm.drop_params is True or drop_params is True:
|
||||
return None
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
elif tool_choice == "required":
|
||||
return ToolChoiceValuesBlock(any={})
|
||||
elif tool_choice == "auto":
|
||||
return ToolChoiceValuesBlock(auto={})
|
||||
elif isinstance(tool_choice, dict):
|
||||
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
specific_tool = SpecificToolChoiceBlock(
|
||||
name=tool_choice.get("function", {}).get("name", "")
|
||||
)
|
||||
return ToolChoiceValuesBlock(tool=specific_tool)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
def get_supported_image_types(self) -> List[str]:
|
||||
return ["png", "jpeg", "gif", "webp"]
|
||||
|
||||
def get_supported_document_types(self) -> List[str]:
|
||||
return ["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"]
|
||||
|
||||
def get_all_supported_content_types(self) -> List[str]:
|
||||
return self.get_supported_image_types() + self.get_supported_document_types()
|
||||
|
||||
def _create_json_tool_call_for_response_format(
|
||||
self,
|
||||
json_schema: Optional[dict] = None,
|
||||
schema_name: str = "json_tool_call",
|
||||
description: Optional[str] = None,
|
||||
) -> ChatCompletionToolParam:
|
||||
"""
|
||||
Handles creating a tool call for getting responses in JSON format.
|
||||
|
||||
Args:
|
||||
json_schema (Optional[dict]): The JSON schema the response should be in
|
||||
|
||||
Returns:
|
||||
ChatCompletionToolParam: The tool call to send to the Bedrock Converse API to get responses in JSON format
|
||||
"""
|
||||
|
||||
if json_schema is None:
|
||||
# Anthropic raises a 400 BadRequest error if properties is passed as None
|
||||
# see usage with additionalProperties (Example 5) https://github.com/anthropics/anthropic-cookbook/blob/main/tool_use/extracting_structured_json.ipynb
|
||||
_input_schema = {
|
||||
"type": "object",
|
||||
"additionalProperties": True,
|
||||
"properties": {},
|
||||
}
|
||||
else:
|
||||
_input_schema = json_schema
|
||||
|
||||
tool_param_function_chunk = ChatCompletionToolParamFunctionChunk(
|
||||
name=schema_name, parameters=_input_schema
|
||||
)
|
||||
if description:
|
||||
tool_param_function_chunk["description"] = description
|
||||
|
||||
_tool = ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=tool_param_function_chunk,
|
||||
)
|
||||
return _tool
|
||||
|
||||
def _apply_tool_call_transformation(
|
||||
self,
|
||||
tools: List[OpenAIChatCompletionToolParam],
|
||||
model: str,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=tools
|
||||
)
|
||||
|
||||
if (
|
||||
"meta.llama3-3-70b-instruct-v1:0" in model
|
||||
and non_default_params.get("stream", False) is True
|
||||
):
|
||||
optional_params["fake_stream"] = True
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
is_thinking_enabled = self.is_thinking_enabled(non_default_params)
|
||||
|
||||
for param, value in non_default_params.items():
|
||||
if param == "response_format" and isinstance(value, dict):
|
||||
ignore_response_format_types = ["text"]
|
||||
if value["type"] in ignore_response_format_types: # value is a no-op
|
||||
continue
|
||||
|
||||
json_schema: Optional[dict] = None
|
||||
schema_name: str = ""
|
||||
description: Optional[str] = None
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
schema_name = "json_tool_call"
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
schema_name = value["json_schema"]["name"]
|
||||
description = value["json_schema"].get("description")
|
||||
|
||||
if "type" in value and value["type"] == "text":
|
||||
continue
|
||||
|
||||
"""
|
||||
Follow a similar approach to Anthropic - translate to a single tool call.
|
||||
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
_tool = self._create_json_tool_call_for_response_format(
|
||||
json_schema=json_schema,
|
||||
schema_name=schema_name if schema_name != "" else "json_tool_call",
|
||||
description=description,
|
||||
)
|
||||
optional_params = self._add_tools_to_optional_params(
|
||||
optional_params=optional_params, tools=[_tool]
|
||||
)
|
||||
if (
|
||||
litellm.utils.supports_tool_choice(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
)
|
||||
and not is_thinking_enabled
|
||||
):
|
||||
optional_params["tool_choice"] = ToolChoiceValuesBlock(
|
||||
tool=SpecificToolChoiceBlock(
|
||||
name=schema_name if schema_name != "" else "json_tool_call"
|
||||
)
|
||||
)
|
||||
optional_params["json_mode"] = True
|
||||
if non_default_params.get("stream", False) is True:
|
||||
optional_params["fake_stream"] = True
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["maxTokens"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
if len(value) == 0: # converse raises error for empty strings
|
||||
continue
|
||||
value = [value]
|
||||
optional_params["stopSequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["topP"] = value
|
||||
if param == "tools" and isinstance(value, list):
|
||||
self._apply_tool_call_transformation(
|
||||
tools=cast(List[OpenAIChatCompletionToolParam], value),
|
||||
model=model,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
if param == "tool_choice":
|
||||
_tool_choice_value = self.map_tool_choice_values(
|
||||
model=model, tool_choice=value, drop_params=drop_params # type: ignore
|
||||
)
|
||||
if _tool_choice_value is not None:
|
||||
optional_params["tool_choice"] = _tool_choice_value
|
||||
if param == "thinking":
|
||||
optional_params["thinking"] = value
|
||||
elif param == "reasoning_effort" and isinstance(value, str):
|
||||
optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
|
||||
value
|
||||
)
|
||||
|
||||
self.update_optional_params_with_thinking_tokens(
|
||||
non_default_params=non_default_params, optional_params=optional_params
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
@overload
|
||||
def _get_cache_point_block(
|
||||
self,
|
||||
message_block: Union[
|
||||
OpenAIMessageContentListBlock,
|
||||
ChatCompletionUserMessage,
|
||||
ChatCompletionSystemMessage,
|
||||
],
|
||||
block_type: Literal["system"],
|
||||
) -> Optional[SystemContentBlock]:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def _get_cache_point_block(
|
||||
self,
|
||||
message_block: Union[
|
||||
OpenAIMessageContentListBlock,
|
||||
ChatCompletionUserMessage,
|
||||
ChatCompletionSystemMessage,
|
||||
],
|
||||
block_type: Literal["content_block"],
|
||||
) -> Optional[ContentBlock]:
|
||||
pass
|
||||
|
||||
def _get_cache_point_block(
|
||||
self,
|
||||
message_block: Union[
|
||||
OpenAIMessageContentListBlock,
|
||||
ChatCompletionUserMessage,
|
||||
ChatCompletionSystemMessage,
|
||||
],
|
||||
block_type: Literal["system", "content_block"],
|
||||
) -> Optional[Union[SystemContentBlock, ContentBlock]]:
|
||||
if message_block.get("cache_control", None) is None:
|
||||
return None
|
||||
if block_type == "system":
|
||||
return SystemContentBlock(cachePoint=CachePointBlock(type="default"))
|
||||
else:
|
||||
return ContentBlock(cachePoint=CachePointBlock(type="default"))
|
||||
|
||||
def _transform_system_message(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> Tuple[List[AllMessageValues], List[SystemContentBlock]]:
|
||||
system_prompt_indices = []
|
||||
system_content_blocks: List[SystemContentBlock] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
system_prompt_indices.append(idx)
|
||||
if isinstance(message["content"], str) and message["content"]:
|
||||
system_content_blocks.append(
|
||||
SystemContentBlock(text=message["content"])
|
||||
)
|
||||
cache_block = self._get_cache_point_block(
|
||||
message, block_type="system"
|
||||
)
|
||||
if cache_block:
|
||||
system_content_blocks.append(cache_block)
|
||||
elif isinstance(message["content"], list):
|
||||
for m in message["content"]:
|
||||
if m.get("type") == "text" and m.get("text"):
|
||||
system_content_blocks.append(
|
||||
SystemContentBlock(text=m["text"])
|
||||
)
|
||||
cache_block = self._get_cache_point_block(
|
||||
m, block_type="system"
|
||||
)
|
||||
if cache_block:
|
||||
system_content_blocks.append(cache_block)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
return messages, system_content_blocks
|
||||
|
||||
def _transform_inference_params(self, inference_params: dict) -> InferenceConfig:
|
||||
if "top_k" in inference_params:
|
||||
inference_params["topK"] = inference_params.pop("top_k")
|
||||
return InferenceConfig(**inference_params)
|
||||
|
||||
def _handle_top_k_value(self, model: str, inference_params: dict) -> dict:
|
||||
base_model = BedrockModelInfo.get_base_model(model)
|
||||
|
||||
val_top_k = None
|
||||
if "topK" in inference_params:
|
||||
val_top_k = inference_params.pop("topK")
|
||||
elif "top_k" in inference_params:
|
||||
val_top_k = inference_params.pop("top_k")
|
||||
|
||||
if val_top_k:
|
||||
if base_model.startswith("anthropic"):
|
||||
return {"top_k": val_top_k}
|
||||
if base_model.startswith("amazon.nova"):
|
||||
return {"inferenceConfig": {"topK": val_top_k}}
|
||||
|
||||
return {}
|
||||
|
||||
def _transform_request_helper(
|
||||
self,
|
||||
model: str,
|
||||
system_content_blocks: List[SystemContentBlock],
|
||||
optional_params: dict,
|
||||
messages: Optional[List[AllMessageValues]] = None,
|
||||
) -> CommonRequestObject:
|
||||
## VALIDATE REQUEST
|
||||
"""
|
||||
Bedrock doesn't support tool calling without `tools=` param specified.
|
||||
"""
|
||||
if (
|
||||
"tools" not in optional_params
|
||||
and messages is not None
|
||||
and has_tool_call_blocks(messages)
|
||||
):
|
||||
if litellm.modify_params:
|
||||
optional_params["tools"] = add_dummy_tool(
|
||||
custom_llm_provider="bedrock_converse"
|
||||
)
|
||||
else:
|
||||
raise litellm.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool calling without `tools=` param specified. Pass `tools=` param OR set `litellm.modify_params = True` // `litellm_settings::modify_params: True` to add dummy tool to the request.",
|
||||
model="",
|
||||
llm_provider="bedrock",
|
||||
)
|
||||
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
supported_converse_params = list(
|
||||
AmazonConverseConfig.__annotations__.keys()
|
||||
) + ["top_k"]
|
||||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
supported_config_params = list(self.get_config_blocks().keys())
|
||||
total_supported_params = (
|
||||
supported_converse_params
|
||||
+ supported_tool_call_params
|
||||
+ supported_config_params
|
||||
)
|
||||
inference_params.pop("json_mode", None) # used for handling json_schema
|
||||
|
||||
# keep supported params in 'inference_params', and set all model-specific params in 'additional_request_params'
|
||||
additional_request_params = {
|
||||
k: v for k, v in inference_params.items() if k not in total_supported_params
|
||||
}
|
||||
inference_params = {
|
||||
k: v for k, v in inference_params.items() if k in total_supported_params
|
||||
}
|
||||
|
||||
# Only set the topK value for models that support it
|
||||
additional_request_params.update(
|
||||
self._handle_top_k_value(model, inference_params)
|
||||
)
|
||||
|
||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||
inference_params.pop("tools", [])
|
||||
)
|
||||
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
||||
if len(bedrock_tools) > 0:
|
||||
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
|
||||
"tool_choice", None
|
||||
)
|
||||
bedrock_tool_config = ToolConfigBlock(
|
||||
tools=bedrock_tools,
|
||||
)
|
||||
if tool_choice_values is not None:
|
||||
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||
|
||||
data: CommonRequestObject = {
|
||||
"additionalModelRequestFields": additional_request_params,
|
||||
"system": system_content_blocks,
|
||||
"inferenceConfig": self._transform_inference_params(
|
||||
inference_params=inference_params
|
||||
),
|
||||
}
|
||||
|
||||
# Handle all config blocks
|
||||
for config_name, config_class in self.get_config_blocks().items():
|
||||
config_value = inference_params.pop(config_name, None)
|
||||
if config_value is not None:
|
||||
data[config_name] = config_class(**config_value) # type: ignore
|
||||
|
||||
# Tool Config
|
||||
if bedrock_tool_config is not None:
|
||||
data["toolConfig"] = bedrock_tool_config
|
||||
|
||||
return data
|
||||
|
||||
async def _async_transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> RequestObject:
|
||||
messages, system_content_blocks = self._transform_system_message(messages)
|
||||
## TRANSFORMATION ##
|
||||
|
||||
_data: CommonRequestObject = self._transform_request_helper(
|
||||
model=model,
|
||||
system_content_blocks=system_content_blocks,
|
||||
optional_params=optional_params,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
bedrock_messages = (
|
||||
await BedrockConverseMessagesProcessor._bedrock_converse_messages_pt_async(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
)
|
||||
|
||||
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||
|
||||
return data
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
return cast(
|
||||
dict,
|
||||
self._transform_request(
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
),
|
||||
)
|
||||
|
||||
def _transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> RequestObject:
|
||||
messages, system_content_blocks = self._transform_system_message(messages)
|
||||
|
||||
_data: CommonRequestObject = self._transform_request_helper(
|
||||
model=model,
|
||||
system_content_blocks=system_content_blocks,
|
||||
optional_params=optional_params,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
## TRANSFORMATION ##
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages,
|
||||
model=model,
|
||||
llm_provider="bedrock_converse",
|
||||
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||
)
|
||||
|
||||
data: RequestObject = {"messages": bedrock_messages, **_data}
|
||||
|
||||
return data
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: Logging,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
return self._transform_response(
|
||||
model=model,
|
||||
response=raw_response,
|
||||
model_response=model_response,
|
||||
stream=optional_params.get("stream", False),
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
data=request_data,
|
||||
messages=messages,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def _transform_reasoning_content(
|
||||
self, reasoning_content_blocks: List[BedrockConverseReasoningContentBlock]
|
||||
) -> str:
|
||||
"""
|
||||
Extract the reasoning text from the reasoning content blocks
|
||||
|
||||
Ensures the output is compatible with DeepSeek-style reasoning content.
|
||||
"""
|
||||
reasoning_content_str = ""
|
||||
for block in reasoning_content_blocks:
|
||||
if "reasoningText" in block:
|
||||
reasoning_content_str += block["reasoningText"]["text"]
|
||||
return reasoning_content_str
|
||||
|
||||
def _transform_thinking_blocks(
|
||||
self, thinking_blocks: List[BedrockConverseReasoningContentBlock]
|
||||
) -> List[Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]]:
|
||||
"""Return a consistent format for thinking blocks between Anthropic and Bedrock."""
|
||||
thinking_blocks_list: List[
|
||||
Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
|
||||
] = []
|
||||
for block in thinking_blocks:
|
||||
if "reasoningText" in block:
|
||||
_thinking_block = ChatCompletionThinkingBlock(type="thinking")
|
||||
_text = block["reasoningText"].get("text")
|
||||
_signature = block["reasoningText"].get("signature")
|
||||
if _text is not None:
|
||||
_thinking_block["thinking"] = _text
|
||||
if _signature is not None:
|
||||
_thinking_block["signature"] = _signature
|
||||
thinking_blocks_list.append(_thinking_block)
|
||||
elif "redactedContent" in block:
|
||||
_redacted_block = ChatCompletionRedactedThinkingBlock(
|
||||
type="redacted_thinking", data=block["redactedContent"]
|
||||
)
|
||||
thinking_blocks_list.append(_redacted_block)
|
||||
return thinking_blocks_list
|
||||
|
||||
def _transform_usage(self, usage: ConverseTokenUsageBlock) -> Usage:
|
||||
input_tokens = usage["inputTokens"]
|
||||
output_tokens = usage["outputTokens"]
|
||||
total_tokens = usage["totalTokens"]
|
||||
cache_creation_input_tokens: int = 0
|
||||
cache_read_input_tokens: int = 0
|
||||
|
||||
if "cacheReadInputTokens" in usage:
|
||||
cache_read_input_tokens = usage["cacheReadInputTokens"]
|
||||
input_tokens += cache_read_input_tokens
|
||||
if "cacheWriteInputTokens" in usage:
|
||||
"""
|
||||
Do not increment prompt_tokens with cacheWriteInputTokens
|
||||
"""
|
||||
cache_creation_input_tokens = usage["cacheWriteInputTokens"]
|
||||
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
cached_tokens=cache_read_input_tokens
|
||||
)
|
||||
openai_usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
)
|
||||
return openai_usage
|
||||
|
||||
def _transform_response(
|
||||
self,
|
||||
model: str,
|
||||
response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Optional[Logging],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str],
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
encoding,
|
||||
) -> ModelResponse:
|
||||
## LOGGING
|
||||
if logging_obj is not None:
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
json_mode: Optional[bool] = optional_params.pop("json_mode", None)
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||
response.text, str(e)
|
||||
),
|
||||
status_code=422,
|
||||
)
|
||||
|
||||
"""
|
||||
Bedrock Response Object has optional message block
|
||||
|
||||
completion_response["output"].get("message", None)
|
||||
|
||||
A message block looks like this (Example 1):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
(Example 2):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
|
||||
"name": "top_song",
|
||||
"input": {
|
||||
"sign": "WZPZ"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
message: Optional[MessageBlock] = completion_response["output"]["message"]
|
||||
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||
content_str = ""
|
||||
tools: List[ChatCompletionToolCallChunk] = []
|
||||
reasoningContentBlocks: Optional[
|
||||
List[BedrockConverseReasoningContentBlock]
|
||||
] = None
|
||||
|
||||
if message is not None:
|
||||
for idx, content in enumerate(message["content"]):
|
||||
"""
|
||||
- Content is either a tool response or text
|
||||
"""
|
||||
if "text" in content:
|
||||
content_str += content["text"]
|
||||
if "toolUse" in content:
|
||||
## check tool name was formatted by litellm
|
||||
_response_tool_name = content["toolUse"]["name"]
|
||||
response_tool_name = get_bedrock_tool_name(
|
||||
response_tool_name=_response_tool_name
|
||||
)
|
||||
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||
name=response_tool_name,
|
||||
arguments=json.dumps(content["toolUse"]["input"]),
|
||||
)
|
||||
|
||||
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||
id=content["toolUse"]["toolUseId"],
|
||||
type="function",
|
||||
function=_function_chunk,
|
||||
index=idx,
|
||||
)
|
||||
tools.append(_tool_response_chunk)
|
||||
if "reasoningContent" in content:
|
||||
if reasoningContentBlocks is None:
|
||||
reasoningContentBlocks = []
|
||||
reasoningContentBlocks.append(content["reasoningContent"])
|
||||
|
||||
if reasoningContentBlocks is not None:
|
||||
chat_completion_message["provider_specific_fields"] = {
|
||||
"reasoningContentBlocks": reasoningContentBlocks,
|
||||
}
|
||||
chat_completion_message[
|
||||
"reasoning_content"
|
||||
] = self._transform_reasoning_content(reasoningContentBlocks)
|
||||
chat_completion_message[
|
||||
"thinking_blocks"
|
||||
] = self._transform_thinking_blocks(reasoningContentBlocks)
|
||||
chat_completion_message["content"] = content_str
|
||||
if json_mode is True and tools is not None and len(tools) == 1:
|
||||
# to support 'json_schema' logic on bedrock models
|
||||
json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments")
|
||||
if json_mode_content_str is not None:
|
||||
chat_completion_message["content"] = json_mode_content_str
|
||||
else:
|
||||
chat_completion_message["tool_calls"] = tools
|
||||
|
||||
## CALCULATING USAGE - Bedrock Converse returns usage in the response body
|
||||
usage = self._transform_usage(completion_response["usage"])
|
||||
|
||||
model_response.choices = [
|
||||
litellm.Choices(
|
||||
finish_reason=map_finish_reason(completion_response["stopReason"]),
|
||||
index=0,
|
||||
message=litellm.Message(**chat_completion_message),
|
||||
)
|
||||
]
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
# Add "trace" from Bedrock guardrails - if user has opted in to returning it
|
||||
if "trace" in completion_response:
|
||||
setattr(model_response, "trace", completion_response["trace"])
|
||||
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return BedrockError(
|
||||
message=error_message,
|
||||
status_code=status_code,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
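A rough sketch of how the transformation above is typically exercised on its own; the model ID and parameter values are illustrative, and the private helper `_transform_request` is called directly only for demonstration.

import litellm

config = litellm.AmazonConverseConfig()

# OpenAI-style params -> Converse inference params (roughly {"maxTokens": 256, "temperature": 0.2, "topP": 0.9})
optional_params = config.map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2, "top_p": 0.9},
    optional_params={},
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    drop_params=False,
)

# OpenAI-style messages -> Converse request body with "messages", "system",
# "inferenceConfig" and "additionalModelRequestFields" keys.
request_body = config._transform_request(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hello"},
    ],
    optional_params=optional_params,
    litellm_params={},
)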
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,99 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
||||
|
||||
Supported Params for the Amazon / AI21 models:
|
||||
|
||||
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
|
||||
|
||||
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
|
||||
|
||||
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
|
||||
|
||||
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
|
||||
|
||||
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
|
||||
|
||||
- `presencePenalty` (object): Placeholder for presence penalty object.
|
||||
|
||||
- `countPenalty` (object): Placeholder for count penalty object.
|
||||
"""
|
||||
|
||||
maxTokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
stopSequences: Optional[list] = None
|
||||
frequencePenalty: Optional[dict] = None
|
||||
presencePenalty: Optional[dict] = None
|
||||
countPenalty: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[float] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
frequencePenalty: Optional[dict] = None,
|
||||
presencePenalty: Optional[dict] = None,
|
||||
countPenalty: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
AmazonInvokeConfig.__init__(self)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
return [
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "max_tokens":
|
||||
optional_params["maxTokens"] = v
|
||||
if k == "temperature":
|
||||
optional_params["temperature"] = v
|
||||
if k == "top_p":
|
||||
optional_params["topP"] = v
|
||||
if k == "stream":
|
||||
optional_params["stream"] = v
|
||||
return optional_params
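A small sketch of the parameter mapping above; it assumes the class defined here is in scope (it is typically also exposed as `litellm.AmazonAI21Config`), and the model ID is illustrative.

config = AmazonAI21Config()
params = config.map_openai_params(
    non_default_params={"max_tokens": 128, "temperature": 0.5, "top_p": 0.8, "stream": False},
    optional_params={},
    model="ai21.j2-ultra-v1",
    drop_params=False,
)
# -> {"maxTokens": 128, "temperature": 0.5, "topP": 0.8, "stream": False}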
|
||||
@@ -0,0 +1,75 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.cohere.chat.transformation import CohereChatConfig
|
||||
|
||||
|
||||
class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
||||
|
||||
Supported Params for the Amazon / Cohere models:
|
||||
|
||||
- `max_tokens` (integer) max tokens,
|
||||
- `temperature` (float) model temperature,
|
||||
- `return_likelihood` (string) n/a
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
return_likelihood: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
return_likelihood: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
AmazonInvokeConfig.__init__(self)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
supported_params = CohereChatConfig.get_supported_openai_params(
|
||||
self, model=model
|
||||
)
|
||||
return supported_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return CohereChatConfig.map_openai_params(
|
||||
self,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from httpx import Response
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_parse_content_for_reasoning,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.bedrock import AmazonDeepSeekR1StreamingResponse
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionUsageBlock,
|
||||
Choices,
|
||||
Delta,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
from .amazon_llama_transformation import AmazonLlamaConfig
|
||||
|
||||
|
||||
class AmazonDeepSeekR1Config(AmazonLlamaConfig):
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Extract the reasoning content, and return it as a separate field in the response.
|
||||
"""
|
||||
response = super().transform_response(
|
||||
model,
|
||||
raw_response,
|
||||
model_response,
|
||||
logging_obj,
|
||||
request_data,
|
||||
messages,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
encoding,
|
||||
api_key,
|
||||
json_mode,
|
||||
)
|
||||
prompt = cast(Optional[str], request_data.get("prompt"))
|
||||
message_content = cast(
|
||||
Optional[str], cast(Choices, response.choices[0]).message.get("content")
|
||||
)
|
||||
if prompt and prompt.strip().endswith("<think>") and message_content:
|
||||
message_content_with_reasoning_token = "<think>" + message_content
|
||||
reasoning, content = _parse_content_for_reasoning(
|
||||
message_content_with_reasoning_token
|
||||
)
|
||||
provider_specific_fields = (
|
||||
cast(Choices, response.choices[0]).message.provider_specific_fields
|
||||
or {}
|
||||
)
|
||||
if reasoning:
|
||||
provider_specific_fields["reasoning_content"] = reasoning
|
||||
|
||||
message = Message(
|
||||
**{
|
||||
**cast(Choices, response.choices[0]).message.model_dump(),
|
||||
"content": content,
|
||||
"provider_specific_fields": provider_specific_fields,
|
||||
}
|
||||
)
|
||||
cast(Choices, response.choices[0]).message = message
|
||||
return response
|
||||
|
||||
|
||||
class AmazonDeepseekR1ResponseIterator(BaseModelResponseIterator):
|
||||
def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
|
||||
super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)
|
||||
self.has_finished_thinking = False
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||
"""
|
||||
Deepseek r1 starts by thinking, then it generates the response.
|
||||
"""
|
||||
try:
|
||||
typed_chunk = AmazonDeepSeekR1StreamingResponse(**chunk) # type: ignore
|
||||
generated_content = typed_chunk["generation"]
|
||||
if generated_content == "</think>" and not self.has_finished_thinking:
|
||||
verbose_logger.debug(
|
||||
"Deepseek r1: </think> received, setting has_finished_thinking to True"
|
||||
)
|
||||
generated_content = ""
|
||||
self.has_finished_thinking = True
|
||||
|
||||
prompt_token_count = typed_chunk.get("prompt_token_count") or 0
|
||||
generation_token_count = typed_chunk.get("generation_token_count") or 0
|
||||
usage = ChatCompletionUsageBlock(
|
||||
prompt_tokens=prompt_token_count,
|
||||
completion_tokens=generation_token_count,
|
||||
total_tokens=prompt_token_count + generation_token_count,
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=typed_chunk["stop_reason"],
|
||||
delta=Delta(
|
||||
content=(
|
||||
generated_content
|
||||
if self.has_finished_thinking
|
||||
else None
|
||||
),
|
||||
reasoning_content=(
|
||||
generated_content
|
||||
if not self.has_finished_thinking
|
||||
else None
|
||||
),
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise e
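A hedged sketch of how the iterator above separates reasoning from the final answer; the chunk dicts are hand-written stand-ins for the Bedrock Invoke streaming payload, and constructing the iterator with an empty stream is assumed to be harmless for demonstration.

# Stand-in chunks; a real stream would come from the Bedrock Invoke response decoder.
parser = AmazonDeepseekR1ResponseIterator(streaming_response=iter([]), sync_stream=True)

first = parser.chunk_parser({"generation": "Let me think this through...", "stop_reason": None})
# first.choices[0].delta.reasoning_content == "Let me think this through..."; delta.content is None

parser.chunk_parser({"generation": "</think>", "stop_reason": None})  # flips has_finished_thinking

last = parser.chunk_parser({"generation": "The answer is 42.", "stop_reason": "stop"})
# last.choices[0].delta.content == "The answer is 42."; delta.reasoning_content is None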
|
||||
@@ -0,0 +1,80 @@
import types
from typing import List, Optional

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)


class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None

    def __init__(
        self,
        max_gen_len: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_gen_len"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params
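For illustration, a short sketch (assuming litellm is installed; the base transformation later in this commit accesses this config as litellm.AmazonLlamaConfig) of how the Llama config maps OpenAI-style parameters onto the invoke body, with `max_tokens` renamed to `max_gen_len`:

import litellm

mapped = litellm.AmazonLlamaConfig().map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.2, "top_p": 0.9, "stream": True},
    optional_params={},
    model="meta.llama3-70b-instruct-v1:0",  # example Bedrock model id
    drop_params=False,
)
# mapped == {"max_gen_len": 256, "temperature": 0.2, "top_p": 0.9, "stream": True}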
@@ -0,0 +1,83 @@
import types
from typing import List, Optional

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)


class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html

    Supported Params for the Amazon / Mistral models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` [string] A list of stop sequences that, if generated by the model, stop it from generating further output.
    - `top_k` (float) top k for model
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        return ["max_tokens", "temperature", "top_p", "stop", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_tokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stop":
                optional_params["stop"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params
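A sketch of the class-level defaults mechanism shared by these configs: values passed to the constructor are registered on the class via setattr, get_config() exposes them, and the base invoke transform_request later in this commit merges them into any Mistral body that does not already set them:

import litellm

litellm.AmazonMistralConfig(max_tokens=512, temperature=0.3)  # registers class-level defaults
print(litellm.AmazonMistralConfig.get_config())
# {"max_tokens": 512, "temperature": 0.3}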
@@ -0,0 +1,121 @@
"""
Handles transforming requests for `bedrock/invoke/{nova}` models

Inherits from `AmazonConverseConfig`

Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
"""

from typing import Any, List, Optional

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

from ..converse_transformation import AmazonConverseConfig
from .base_invoke_transformation import AmazonInvokeConfig


class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
    """
    Config for sending `nova` requests to `/bedrock/invoke/`
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_supported_openai_params(self, model: str) -> list:
        return AmazonConverseConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return AmazonConverseConfig.map_openai_params(
            self, non_default_params, optional_params, model, drop_params
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        _transformed_nova_request = AmazonConverseConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        _bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
            **_transformed_nova_request
        )
        self._remove_empty_system_messages(_bedrock_invoke_nova_request)
        bedrock_invoke_nova_request = self._filter_allowed_fields(
            _bedrock_invoke_nova_request
        )
        return bedrock_invoke_nova_request

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: Logging,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> litellm.ModelResponse:
        return AmazonConverseConfig.transform_response(
            self,
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )

    def _filter_allowed_fields(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> dict:
        """
        Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` TypedDict.
        """
        allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
        return {
            k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
        }

    def _remove_empty_system_messages(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> None:
        """
        In-place remove empty `system` messages from the request.

        /bedrock/invoke/ does not allow empty `system` messages.
        """
        _system_message = bedrock_invoke_nova_request.get("system", None)
        if isinstance(_system_message, list) and len(_system_message) == 0:
            bedrock_invoke_nova_request.pop("system", None)
        return
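A quick sketch of the two Nova-specific cleanups above: an empty `system` list is removed in place, and any key not annotated on `BedrockInvokeNovaRequest` is dropped before the body is sent. The field names in the sample request are assumptions based on the converse-style schema:

import litellm
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest

config = litellm.AmazonInvokeNovaConfig()
request = BedrockInvokeNovaRequest(
    messages=[{"role": "user", "content": [{"text": "hi"}]}],  # assumed field name
    system=[],  # empty system block, rejected by /bedrock/invoke/
)
config._remove_empty_system_messages(request)
cleaned = config._filter_allowed_fields(request)
# "system" is gone; only keys declared on BedrockInvokeNovaRequest remain.
print(cleaned)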
@@ -0,0 +1,116 @@
import re
import types
from typing import List, Optional, Union

import litellm
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)


class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (int) top p for model
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[int] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def _map_and_modify_arg(
        self,
        supported_params: dict,
        provider: str,
        model: str,
        stop: Union[List[str], str],
    ):
        """
        filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
        """
        filtered_stop = None
        if "stop" in supported_params and litellm.drop_params:
            if provider == "bedrock" and "amazon" in model:
                filtered_stop = []
                if isinstance(stop, list):
                    for s in stop:
                        if re.match(r"^(\|+|User:)$", s):
                            filtered_stop.append(s)
        if filtered_stop is not None:
            supported_params["stop"] = filtered_stop

        return supported_params

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens" or k == "max_completion_tokens":
                optional_params["maxTokenCount"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "stop":
                filtered_stop = self._map_and_modify_arg(
                    {"stop": v}, provider="bedrock", model=model, stop=v
                )
                optional_params["stopSequences"] = filtered_stop["stop"]
            if k == "top_p":
                optional_params["topP"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params
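The `stop` handling above is the subtle part: when `litellm.drop_params` is enabled, only stop sequences Titan accepts (runs of `|` or the literal `User:`) survive the regex filter. A small sketch, assuming litellm is installed:

import litellm

litellm.drop_params = True  # enables the stop-sequence filtering path
mapped = litellm.AmazonTitanConfig().map_openai_params(
    non_default_params={"max_tokens": 128, "stop": ["User:", "###", "||"]},
    optional_params={},
    model="amazon.titan-text-express-v1",
    drop_params=True,
)
# "###" does not match r"^(\|+|User:)$" and is dropped; the others are kept.
# mapped == {"maxTokenCount": 128, "stopSequences": ["User:", "||"]}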
@@ -0,0 +1,90 @@
import types
from typing import Optional

import litellm

from .base_invoke_transformation import AmazonInvokeConfig


class AmazonAnthropicConfig(AmazonInvokeConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str):
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
        return optional_params
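For the pre-Claude-3 text-completion models, the same OpenAI parameters land on Anthropic's legacy field names, e.g. `max_tokens` becomes `max_tokens_to_sample`. A sketch, assuming this config is exported at the litellm top level like its siblings:

import litellm

mapped = litellm.AmazonAnthropicConfig().map_openai_params(
    non_default_params={"max_tokens": 200, "stop": ["\n\nHuman:"], "stream": True},
    optional_params={},
    model="anthropic.claude-v2",
    drop_params=False,
)
# mapped == {"max_tokens_to_sample": 200, "stop_sequences": ["\n\nHuman:"], "stream": True}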
@@ -0,0 +1,100 @@
from typing import TYPE_CHECKING, Any, List, Optional

import httpx

from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class AmazonAnthropicClaude3Config(AmazonInvokeConfig, AnthropicConfig):
    """
    Reference:
    https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
    https://docs.anthropic.com/claude/docs/models-overview#model-comparison

    Supported Params for the Amazon / Anthropic Claude 3 models:
    """

    anthropic_version: str = "bedrock-2023-05-31"

    def get_supported_openai_params(self, model: str) -> List[str]:
        return AnthropicConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return AnthropicConfig.map_openai_params(
            self,
            non_default_params,
            optional_params,
            model,
            drop_params,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        _anthropic_request = AnthropicConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        _anthropic_request.pop("model", None)
        _anthropic_request.pop("stream", None)
        if "anthropic_version" not in _anthropic_request:
            _anthropic_request["anthropic_version"] = self.anthropic_version

        return _anthropic_request

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        return AnthropicConfig.transform_response(
            self,
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )
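A sketch of what the Claude 3 transform above produces: the Anthropic-shaped body minus `model` and `stream` (both are carried by the invoke URL and streaming endpoint instead), with `anthropic_version` injected when absent. The exact keyword handling inside AnthropicConfig.transform_request may need more context than shown here:

import litellm

body = litellm.AmazonAnthropicClaude3Config().transform_request(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={"max_tokens": 100, "stream": True},
    litellm_params={},
    headers={},
)
assert "model" not in body and "stream" not in body
assert body["anthropic_version"] == "bedrock-2023-05-31"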
@@ -0,0 +1,679 @@
import copy
import json
import time
import urllib.parse
from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args

import httpx

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
    cohere_message_pt,
    custom_prompt,
    deepseek_r1_pt,
    prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import make_call, make_sync_call
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import CustomStreamWrapper

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any

from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM


class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
    def __init__(self, **kwargs):
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
        """
        return [
            "max_tokens",
            "max_completion_tokens",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
        """
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            if param == "stream":
                optional_params["stream"] = value
        return optional_params

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete url for the request
        """
        provider = self.get_bedrock_invoke_provider(model)
        modelId = self.get_bedrock_model_id(
            model=model,
            provider=provider,
            optional_params=optional_params,
        )
        ### SET RUNTIME ENDPOINT ###
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=self._get_aws_region_name(
                optional_params=optional_params, model=model
            ),
        )

        if (stream is not None and stream is True) and provider != "ai21":
            endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
            proxy_endpoint_url = (
                f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
            )
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"

        return endpoint_url

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> dict:
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.get("aws_secret_access_key", None)
        aws_access_key_id = optional_params.get("aws_access_key_id", None)
        aws_session_token = optional_params.get("aws_session_token", None)
        aws_role_name = optional_params.get("aws_role_name", None)
        aws_session_name = optional_params.get("aws_session_name", None)
        aws_profile_name = optional_params.get("aws_profile_name", None)
        aws_web_identity_token = optional_params.get("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.get("aws_sts_endpoint", None)
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params, model=model
        )

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
        if headers is not None:
            headers = {"Content-Type": "application/json", **headers}
        else:
            headers = {"Content-Type": "application/json"}

        request = AWSRequest(
            method="POST",
            url=api_base,
            data=json.dumps(request_data),
            headers=headers,
        )
        sigv4.add_auth(request)

        request_headers_dict = dict(request.headers)
        if (
            headers is not None and "Authorization" in headers
        ):  # prevent sigv4 from overwriting the auth header
            request_headers_dict["Authorization"] = headers["Authorization"]
        return request_headers_dict

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        ## SETUP ##
        stream = optional_params.pop("stream", None)
        custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
        hf_model_name = litellm_params.get("hf_model_name", None)

        provider = self.get_bedrock_invoke_provider(model)

        prompt, chat_history = self.convert_messages_to_prompt(
            model=hf_model_name or model,
            messages=messages,
            provider=provider,
            custom_prompt_dict=custom_prompt_dict,
        )
        inference_params = copy.deepcopy(optional_params)
        inference_params = {
            k: v
            for k, v in inference_params.items()
            if k not in self.aws_authentication_params
        }
        request_data: dict = {}
        if provider == "cohere":
            if model.startswith("cohere.command-r"):
                ## LOAD CONFIG
                config = litellm.AmazonCohereChatConfig().get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                _data = {"message": prompt, **inference_params}
                if chat_history is not None:
                    _data["chat_history"] = chat_history
                request_data = _data
            else:
                ## LOAD CONFIG
                config = litellm.AmazonCohereConfig.get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                if stream is True:
                    inference_params[
                        "stream"
                    ] = True  # cohere requires stream = True in inference params
                request_data = {"prompt": prompt, **inference_params}
        elif provider == "anthropic":
            return litellm.AmazonAnthropicClaude3Config().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif provider == "nova":
            return litellm.AmazonInvokeNovaConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif provider == "ai21":
            ## LOAD CONFIG
            config = litellm.AmazonAI21Config.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v

            request_data = {"prompt": prompt, **inference_params}
        elif provider == "mistral":
            ## LOAD CONFIG
            config = litellm.AmazonMistralConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v

            request_data = {"prompt": prompt, **inference_params}
        elif provider == "amazon":  # amazon titan
            ## LOAD CONFIG
            config = litellm.AmazonTitanConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v

            request_data = {
                "inputText": prompt,
                "textGenerationConfig": inference_params,
            }
        elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
            ## LOAD CONFIG
            config = litellm.AmazonLlamaConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            request_data = {"prompt": prompt, **inference_params}
        else:
            raise BedrockError(
                status_code=404,
                message="Bedrock Invoke HTTPX: Unknown provider={}, model={}. Try calling via converse route - `bedrock/converse/<model>`.".format(
                    provider, model
                ),
            )

        return request_data

    def transform_response(  # noqa: PLR0915
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            completion_response = raw_response.json()
        except Exception:
            raise BedrockError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        verbose_logger.debug(
            "bedrock invoke response % s",
            json.dumps(completion_response, indent=4, default=str),
        )
        provider = self.get_bedrock_invoke_provider(model)
        outputText: Optional[str] = None
        try:
            if provider == "cohere":
                if "text" in completion_response:
                    outputText = completion_response["text"]  # type: ignore
                elif "generations" in completion_response:
                    outputText = completion_response["generations"][0]["text"]
                    model_response.choices[0].finish_reason = map_finish_reason(
                        completion_response["generations"][0]["finish_reason"]
                    )
            elif provider == "anthropic":
                return litellm.AmazonAnthropicClaude3Config().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    api_key=api_key,
                    json_mode=json_mode,
                )
            elif provider == "nova":
                return litellm.AmazonInvokeNovaConfig().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                )
            elif provider == "ai21":
                outputText = (
                    completion_response.get("completions")[0].get("data").get("text")
                )
            elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
                outputText = completion_response["generation"]
            elif provider == "mistral":
                outputText = completion_response["outputs"][0]["text"]
                model_response.choices[0].finish_reason = completion_response[
                    "outputs"
                ][0]["stop_reason"]
            else:  # amazon titan
                outputText = completion_response.get("results")[0].get("outputText")
        except Exception as e:
            raise BedrockError(
                message="Error processing={}, Received error={}".format(
                    raw_response.text, str(e)
                ),
                status_code=422,
            )

        try:
            if (
                outputText is not None
                and len(outputText) > 0
                and hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
                is None
            ):
                model_response.choices[0].message.content = outputText  # type: ignore
            elif (
                hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
                is not None
            ):
                pass
            else:
                raise Exception()
        except Exception as e:
            raise BedrockError(
                message="Error parsing received text={}.\nError-{}".format(
                    outputText, str(e)
                ),
                status_code=raw_response.status_code,
            )

        ## CALCULATING USAGE - bedrock returns usage in the headers
        bedrock_input_tokens = raw_response.headers.get(
            "x-amzn-bedrock-input-token-count", None
        )
        bedrock_output_tokens = raw_response.headers.get(
            "x-amzn-bedrock-output-token-count", None
        )

        prompt_tokens = int(
            bedrock_input_tokens or litellm.token_counter(messages=messages)
        )

        completion_tokens = int(
            bedrock_output_tokens
            or litellm.token_counter(
                text=model_response.choices[0].message.content,  # type: ignore
                count_response_tokens=True,
            )
        )

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)

        return model_response

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return BedrockError(status_code=status_code, message=error_message)

    @track_llm_api_timing()
    def get_async_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[AsyncHTTPHandler] = None,
        json_mode: Optional[bool] = None,
    ) -> CustomStreamWrapper:
        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                client=client,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                fake_stream=True if "ai21" in api_base else False,
                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
                json_mode=json_mode,
            ),
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response

    @track_llm_api_timing()
    def get_sync_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        json_mode: Optional[bool] = None,
    ) -> CustomStreamWrapper:
        if client is None or isinstance(client, AsyncHTTPHandler):
            client = _get_httpx_client(params={})
        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_sync_call,
                client=client,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                fake_stream=True if "ai21" in api_base else False,
                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
                json_mode=json_mode,
            ),
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response

    @property
    def has_custom_stream_wrapper(self) -> bool:
        return True

    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """
        Bedrock invoke does not allow passing `stream` in the request body.
        """
        return False

    @staticmethod
    def get_bedrock_invoke_provider(
        model: str,
    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the bedrock provider from the model

        handles 4 scenarios:
        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
        """
        if model.startswith("invoke/"):
            model = model.replace("invoke/", "", 1)

        _split_model = model.split(".")[0]
        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)

        # If not a known provider, check for pattern with two slashes
        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
        if provider is not None:
            return provider

        # check if provider == "nova"
        if "nova" in model:
            return "nova"

        for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            if provider in model:
                return provider
        return None

    @staticmethod
    def _get_provider_from_model_path(
        model_path: str,
    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
        """
        Helper function to get the provider from a model path with format: provider/model-name

        Args:
            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')

        Returns:
            Optional[str]: The provider name, or None if no valid provider found
        """
        parts = model_path.split("/")
        if len(parts) >= 1:
            provider = parts[0]
            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
        return None

    def get_bedrock_model_id(
        self,
        optional_params: dict,
        provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL],
        model: str,
    ) -> str:
        modelId = optional_params.pop("model_id", None)
        if modelId is not None:
            modelId = self.encode_model_id(model_id=modelId)
        else:
            modelId = model

        modelId = modelId.replace("invoke/", "", 1)
        if provider == "llama" and "llama/" in modelId:
            modelId = self._get_model_id_from_model_with_spec(modelId, spec="llama")
        elif provider == "deepseek_r1" and "deepseek_r1/" in modelId:
            modelId = self._get_model_id_from_model_with_spec(
                modelId, spec="deepseek_r1"
            )
        return modelId

    def _get_model_id_from_model_with_spec(
        self,
        model: str,
        spec: str,
    ) -> str:
        """
        Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
        """
        model_id = model.replace(spec + "/", "")
        return self.encode_model_id(model_id=model_id)

    def encode_model_id(self, model_id: str) -> str:
        """
        Double encode the model ID to ensure it matches the expected double-encoded format.
        Args:
            model_id (str): The model ID to encode.
        Returns:
            str: The double-encoded model ID.
        """
        return urllib.parse.quote(model_id, safe="")

    def convert_messages_to_prompt(
        self, model, messages, provider, custom_prompt_dict
    ) -> Tuple[str, Optional[list]]:
        # handle anthropic prompts and amazon titan prompts
        prompt = ""
        chat_history: Optional[list] = None
        ## CUSTOM PROMPT
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
            return prompt, None
        ## ELSE
        if provider == "anthropic" or provider == "amazon":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "mistral":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "meta" or provider == "llama":
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="bedrock"
            )
        elif provider == "cohere":
            prompt, chat_history = cohere_message_pt(messages=messages)
        elif provider == "deepseek_r1":
            prompt = deepseek_r1_pt(messages=messages)
        else:
            prompt = ""
            for message in messages:
                if "role" in message:
                    if message["role"] == "user":
                        prompt += f"{message['content']}"
                    else:
                        prompt += f"{message['content']}"
                else:
                    prompt += f"{message['content']}"
        return prompt, chat_history  # type: ignore
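To tie the base class together, a short sketch of the model-string handling: provider detection follows the four scenarios documented on get_bedrock_invoke_provider, and imported-model ARNs are stripped of their `<spec>/` prefix and URL-encoded before being placed in the `/model/{modelId}/invoke` path:

from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
    AmazonInvokeConfig,
)

# Provider detection mirrors the docstring scenarios above.
assert (
    AmazonInvokeConfig.get_bedrock_invoke_provider(
        "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
    )
    == "anthropic"
)
assert AmazonInvokeConfig.get_bedrock_invoke_provider("us.amazon.nova-pro-v1:0") == "nova"

# ARN-style model ids lose the "llama/" spec prefix and get URL-encoded,
# so they can be embedded safely in the invoke endpoint path.
cfg = AmazonInvokeConfig()
model_id = cfg.get_bedrock_model_id(
    optional_params={},
    provider="llama",
    model="llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",
)
print(model_id)  # arn%3Aaws%3Abedrock%3Aus-east-1%3A...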