structure saas with tools
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,373 @@
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
|
||||
from ..common_utils import validate_environment as cohere_validate_environment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class CohereError(BaseLLMException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Optional[httpx.Headers] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
class CohereChatConfig(BaseConfig):
|
||||
"""
|
||||
Configuration class for Cohere's API interface.
|
||||
|
||||
Args:
|
||||
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
|
||||
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
|
||||
generation_id (str, optional): Unique identifier for the generated reply.
|
||||
response_id (str, optional): Unique identifier for the response.
|
||||
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
|
||||
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
|
||||
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
|
||||
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
|
||||
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
|
||||
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
|
||||
max_tokens [DEPRECATED - use max_completion_tokens] (int, optional): The maximum number of tokens the model will generate as part of the response.
|
||||
max_completion_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
|
||||
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
|
||||
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
|
||||
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
|
||||
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
|
||||
seed (int, optional): A seed to assist reproducibility of the model's response.
|
||||
"""
|
||||
|
||||
preamble: Optional[str] = None
|
||||
chat_history: Optional[list] = None
|
||||
generation_id: Optional[str] = None
|
||||
response_id: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
prompt_truncation: Optional[str] = None
|
||||
connectors: Optional[list] = None
|
||||
search_queries_only: Optional[bool] = None
|
||||
documents: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
max_completion_tokens: Optional[int] = None
|
||||
k: Optional[int] = None
|
||||
p: Optional[int] = None
|
||||
frequency_penalty: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
tools: Optional[list] = None
|
||||
tool_results: Optional[list] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preamble: Optional[str] = None,
|
||||
chat_history: Optional[list] = None,
|
||||
generation_id: Optional[str] = None,
|
||||
response_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
prompt_truncation: Optional[str] = None,
|
||||
connectors: Optional[list] = None,
|
||||
search_queries_only: Optional[bool] = None,
|
||||
documents: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
k: Optional[int] = None,
|
||||
p: Optional[int] = None,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_results: Optional[list] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "n":
|
||||
optional_params["num_generations"] = value
|
||||
if param == "top_p":
|
||||
optional_params["p"] = value
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
if param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "seed":
|
||||
optional_params["seed"] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
## Load Config
|
||||
for k, v in litellm.CohereChatConfig.get_config().items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||
messages=messages, model=model, llm_provider="cohere_chat"
|
||||
)
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if isinstance(most_recent_message, dict):
|
||||
optional_params["tool_results"] = [most_recent_message]
|
||||
elif isinstance(most_recent_message, str):
|
||||
optional_params["message"] = most_recent_message
|
||||
|
||||
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
|
||||
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
|
||||
optional_params["force_single_step"] = True
|
||||
|
||||
return optional_params
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore
|
||||
except Exception:
|
||||
raise CohereError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
## ADD CITATIONS
|
||||
if "citations" in raw_response_json:
|
||||
setattr(model_response, "citations", raw_response_json["citations"])
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = raw_response_json.get("tool_calls", None)
|
||||
if cohere_tools_response is not None and cohere_tools_response != []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls = []
|
||||
for tool in cohere_tools_response:
|
||||
function_name = tool.get("name", "")
|
||||
generation_id = tool.get("generation_id", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
tool_call = {
|
||||
"id": f"call_{generation_id}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(parameters),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
|
||||
|
||||
prompt_tokens = billed_units.get("input_tokens", 0)
|
||||
completion_tokens = billed_units.get("output_tokens", 0)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def _construct_cohere_tool(
|
||||
self,
|
||||
tools: Optional[list] = None,
|
||||
):
|
||||
if tools is None:
|
||||
tools = []
|
||||
cohere_tools = []
|
||||
for tool in tools:
|
||||
cohere_tool = self._translate_openai_tool_to_cohere(tool)
|
||||
cohere_tools.append(cohere_tool)
|
||||
return cohere_tools
|
||||
|
||||
def _translate_openai_tool_to_cohere(
|
||||
self,
|
||||
openai_tool: dict,
|
||||
):
|
||||
# cohere tools look like this
|
||||
"""
|
||||
{
|
||||
"name": "query_daily_sales_report",
|
||||
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
|
||||
"parameter_definitions": {
|
||||
"day": {
|
||||
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
|
||||
"type": "str",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# OpenAI tools look like this
|
||||
"""
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
cohere_tool = {
|
||||
"name": openai_tool["function"]["name"],
|
||||
"description": openai_tool["function"]["description"],
|
||||
"parameter_definitions": {},
|
||||
}
|
||||
|
||||
for param_name, param_def in openai_tool["function"]["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
required_params = (
|
||||
openai_tool.get("function", {})
|
||||
.get("parameters", {})
|
||||
.get("required", [])
|
||||
)
|
||||
cohere_param_def = {
|
||||
"description": param_def.get("description", ""),
|
||||
"type": param_def.get("type", ""),
|
||||
"required": param_name in required_params,
|
||||
}
|
||||
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
|
||||
|
||||
return cohere_tool
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
):
|
||||
return CohereModelResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
@@ -0,0 +1,356 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.cohere import CohereV2ChatResponse
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
from ..common_utils import CohereError
|
||||
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
|
||||
from ..common_utils import validate_environment as cohere_validate_environment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class CohereV2ChatConfig(BaseConfig):
|
||||
"""
|
||||
Configuration class for Cohere's API interface.
|
||||
|
||||
Args:
|
||||
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
|
||||
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
|
||||
generation_id (str, optional): Unique identifier for the generated reply.
|
||||
response_id (str, optional): Unique identifier for the response.
|
||||
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
|
||||
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
|
||||
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
|
||||
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
|
||||
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
|
||||
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
|
||||
max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
|
||||
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
|
||||
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
|
||||
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
|
||||
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
|
||||
seed (int, optional): A seed to assist reproducibility of the model's response.
|
||||
"""
|
||||
|
||||
preamble: Optional[str] = None
|
||||
chat_history: Optional[list] = None
|
||||
generation_id: Optional[str] = None
|
||||
response_id: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
prompt_truncation: Optional[str] = None
|
||||
connectors: Optional[list] = None
|
||||
search_queries_only: Optional[bool] = None
|
||||
documents: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
k: Optional[int] = None
|
||||
p: Optional[int] = None
|
||||
frequency_penalty: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
tools: Optional[list] = None
|
||||
tool_results: Optional[list] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preamble: Optional[str] = None,
|
||||
chat_history: Optional[list] = None,
|
||||
generation_id: Optional[str] = None,
|
||||
response_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
prompt_truncation: Optional[str] = None,
|
||||
connectors: Optional[list] = None,
|
||||
search_queries_only: Optional[bool] = None,
|
||||
documents: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
k: Optional[int] = None,
|
||||
p: Optional[int] = None,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_results: Optional[list] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "n":
|
||||
optional_params["num_generations"] = value
|
||||
if param == "top_p":
|
||||
optional_params["p"] = value
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
if param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "seed":
|
||||
optional_params["seed"] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
## Load Config
|
||||
for k, v in litellm.CohereChatConfig.get_config().items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||
messages=messages, model=model, llm_provider="cohere_chat"
|
||||
)
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if isinstance(most_recent_message, dict):
|
||||
optional_params["tool_results"] = [most_recent_message]
|
||||
elif isinstance(most_recent_message, str):
|
||||
optional_params["message"] = most_recent_message
|
||||
|
||||
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
|
||||
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
|
||||
optional_params["force_single_step"] = True
|
||||
|
||||
return optional_params
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise CohereError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
try:
|
||||
cohere_v2_chat_response = CohereV2ChatResponse(**raw_response_json) # type: ignore
|
||||
except Exception:
|
||||
raise CohereError(message=raw_response.text, status_code=422)
|
||||
|
||||
cohere_content = cohere_v2_chat_response["message"].get("content", None)
|
||||
if cohere_content is not None:
|
||||
model_response.choices[0].message.content = "".join( # type: ignore
|
||||
[
|
||||
content.get("text", "")
|
||||
for content in cohere_content
|
||||
if content is not None
|
||||
]
|
||||
)
|
||||
|
||||
## ADD CITATIONS
|
||||
if "citations" in cohere_v2_chat_response:
|
||||
setattr(model_response, "citations", cohere_v2_chat_response["citations"])
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = cohere_v2_chat_response["message"].get("tool_calls", [])
|
||||
if cohere_tools_response is not None and cohere_tools_response != []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls: List[ChatCompletionToolCallChunk] = []
|
||||
for index, tool in enumerate(cohere_tools_response):
|
||||
tool_call: ChatCompletionToolCallChunk = {
|
||||
**tool, # type: ignore
|
||||
"index": index,
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
token_usage = cohere_v2_chat_response["usage"].get("tokens", {})
|
||||
prompt_tokens = token_usage.get("input_tokens", 0)
|
||||
completion_tokens = token_usage.get("output_tokens", 0)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def _construct_cohere_tool(
|
||||
self,
|
||||
tools: Optional[list] = None,
|
||||
):
|
||||
if tools is None:
|
||||
tools = []
|
||||
cohere_tools = []
|
||||
for tool in tools:
|
||||
cohere_tool = self._translate_openai_tool_to_cohere(tool)
|
||||
cohere_tools.append(cohere_tool)
|
||||
return cohere_tools
|
||||
|
||||
def _translate_openai_tool_to_cohere(
|
||||
self,
|
||||
openai_tool: dict,
|
||||
):
|
||||
# cohere tools look like this
|
||||
"""
|
||||
{
|
||||
"name": "query_daily_sales_report",
|
||||
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
|
||||
"parameter_definitions": {
|
||||
"day": {
|
||||
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
|
||||
"type": "str",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# OpenAI tools look like this
|
||||
"""
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
cohere_tool = {
|
||||
"name": openai_tool["function"]["name"],
|
||||
"description": openai_tool["function"]["description"],
|
||||
"parameter_definitions": {},
|
||||
}
|
||||
|
||||
for param_name, param_def in openai_tool["function"]["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
required_params = (
|
||||
openai_tool.get("function", {})
|
||||
.get("parameters", {})
|
||||
.get("required", [])
|
||||
)
|
||||
cohere_param_def = {
|
||||
"description": param_def.get("description", ""),
|
||||
"type": param_def.get("type", ""),
|
||||
"required": param_name in required_params,
|
||||
}
|
||||
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
|
||||
|
||||
return cohere_tool
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
):
|
||||
return CohereModelResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
@@ -0,0 +1,147 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
GenericStreamingChunk,
|
||||
)
|
||||
|
||||
|
||||
class CohereError(BaseLLMException):
|
||||
def __init__(self, status_code, message):
|
||||
super().__init__(status_code=status_code, message=message)
|
||||
|
||||
|
||||
def validate_environment(
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Return headers to use for cohere chat completion request
|
||||
|
||||
Cohere API Ref: https://docs.cohere.com/reference/chat
|
||||
Expected headers:
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"Authorization": "bearer $CO_API_KEY"
|
||||
}
|
||||
"""
|
||||
headers.update(
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.content_blocks: List = []
|
||||
self.tool_index = -1
|
||||
self.json_mode = json_mode
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields = None
|
||||
|
||||
index = int(chunk.get("index", 0))
|
||||
|
||||
if "text" in chunk:
|
||||
text = chunk["text"]
|
||||
elif "is_finished" in chunk and chunk["is_finished"] is True:
|
||||
is_finished = chunk["is_finished"]
|
||||
finish_reason = chunk["finish_reason"]
|
||||
|
||||
if "citations" in chunk:
|
||||
provider_specific_fields = {"citations": chunk["citations"]}
|
||||
|
||||
returned_chunk = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
)
|
||||
|
||||
return returned_chunk
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
|
||||
"""
|
||||
Convert a string chunk to a GenericStreamingChunk
|
||||
|
||||
Note: This is used for Cohere pass through streaming logging
|
||||
"""
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
data_json = json.loads(str_line)
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Cohere /generate API - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,265 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||
|
||||
from ..common_utils import CohereError
|
||||
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
|
||||
from ..common_utils import validate_environment as cohere_validate_environment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class CohereTextConfig(BaseConfig):
|
||||
"""
|
||||
Reference: https://docs.cohere.com/reference/generate
|
||||
|
||||
The class `CohereConfig` provides configuration for the Cohere's API interface. Below are the parameters:
|
||||
|
||||
- `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5.
|
||||
|
||||
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20.
|
||||
|
||||
- `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END.
|
||||
|
||||
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75.
|
||||
|
||||
- `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc.
|
||||
|
||||
- `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text.
|
||||
|
||||
- `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text.
|
||||
|
||||
- `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0.
|
||||
|
||||
- `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0.
|
||||
|
||||
- `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens.
|
||||
|
||||
- `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared.
|
||||
|
||||
- `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE.
|
||||
|
||||
- `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
|
||||
"""
|
||||
|
||||
num_generations: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
truncate: Optional[str] = None
|
||||
temperature: Optional[int] = None
|
||||
preset: Optional[str] = None
|
||||
end_sequences: Optional[list] = None
|
||||
stop_sequences: Optional[list] = None
|
||||
k: Optional[int] = None
|
||||
p: Optional[int] = None
|
||||
frequency_penalty: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
return_likelihoods: Optional[str] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_generations: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
truncate: Optional[str] = None,
|
||||
temperature: Optional[int] = None,
|
||||
preset: Optional[str] = None,
|
||||
end_sequences: Optional[list] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
k: Optional[int] = None,
|
||||
p: Optional[int] = None,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
return_likelihoods: Optional[str] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"logit_bias",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
elif param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
elif param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
elif param == "n":
|
||||
optional_params["num_generations"] = value
|
||||
elif param == "logit_bias":
|
||||
optional_params["logit_bias"] = value
|
||||
elif param == "top_p":
|
||||
optional_params["p"] = value
|
||||
elif param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
elif param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
elif param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
prompt = " ".join(
|
||||
convert_content_list_to_str(message=message) for message in messages
|
||||
)
|
||||
|
||||
## Load Config
|
||||
config = litellm.CohereConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api(
|
||||
tools=optional_params["tools"]
|
||||
)
|
||||
optional_params["tools"] = tool_calling_system_prompt
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
prompt = " ".join(
|
||||
convert_content_list_to_str(message=message) for message in messages
|
||||
)
|
||||
completion_response = raw_response.json()
|
||||
choices_list = []
|
||||
for idx, item in enumerate(completion_response["generations"]):
|
||||
if len(item["text"]) > 0:
|
||||
message_obj = Message(content=item["text"])
|
||||
else:
|
||||
message_obj = Message(content=None)
|
||||
choice_obj = Choices(
|
||||
finish_reason=item["finish_reason"],
|
||||
index=idx + 1,
|
||||
message=message_obj,
|
||||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response.choices = choices_list # type: ignore
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def _construct_cohere_tool_for_completion_api(
|
||||
self,
|
||||
tools: Optional[List] = None,
|
||||
) -> dict:
|
||||
if tools is None:
|
||||
tools = []
|
||||
return {"tools": tools}
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
):
|
||||
return CohereModelResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,177 @@
|
||||
import json
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.bedrock import CohereEmbeddingRequest
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .transformation import CohereEmbeddingConfig
|
||||
|
||||
|
||||
def validate_environment(api_key, headers: dict):
|
||||
headers.update(
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
class CohereError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.cohere.ai/v1/generate"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
async def async_embedding(
|
||||
model: str,
|
||||
data: Union[dict, CohereEmbeddingRequest],
|
||||
input: list,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
api_base: str,
|
||||
api_key: Optional[str],
|
||||
headers: dict,
|
||||
encoding: Callable,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
):
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
|
||||
if client is None:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.COHERE,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
except httpx.HTTPStatusError as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=e.response.text,
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
## PROCESS RESPONSE ##
|
||||
return CohereEmbeddingConfig()._transform_response(
|
||||
response=response,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
input=input,
|
||||
)
|
||||
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
encoding: Any,
|
||||
data: Optional[Union[dict, CohereEmbeddingRequest]] = None,
|
||||
complete_api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
aembedding: Optional[bool] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
headers = validate_environment(api_key, headers=headers)
|
||||
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
|
||||
model = model
|
||||
|
||||
data = data or CohereEmbeddingConfig()._transform_request(
|
||||
model=model, input=input, inference_params=optional_params
|
||||
)
|
||||
|
||||
## ROUTING
|
||||
if aembedding is True:
|
||||
return async_embedding(
|
||||
model=model,
|
||||
data=data,
|
||||
input=input,
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_base=embed_url,
|
||||
api_key=api_key,
|
||||
headers=headers,
|
||||
encoding=encoding,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(concurrent_limit=1)
|
||||
|
||||
response = client.post(embed_url, headers=headers, data=json.dumps(data))
|
||||
|
||||
return CohereEmbeddingConfig()._transform_response(
|
||||
response=response,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
input=input,
|
||||
)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Cohere's /v1/embed format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Convers
|
||||
- v3 embedding models
|
||||
- v2 embedding models
|
||||
|
||||
Docs - https://docs.cohere.com/v2/reference/embed
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.bedrock import (
|
||||
CohereEmbeddingRequest,
|
||||
CohereEmbeddingRequestWithModel,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage
|
||||
from litellm.utils import is_base64_encoded
|
||||
|
||||
|
||||
class CohereEmbeddingConfig:
|
||||
"""
|
||||
Reference: https://docs.cohere.com/v2/reference/embed
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_supported_openai_params(self) -> List[str]:
|
||||
return ["encoding_format"]
|
||||
|
||||
def map_openai_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "encoding_format":
|
||||
optional_params["embedding_types"] = v
|
||||
return optional_params
|
||||
|
||||
def _is_v3_model(self, model: str) -> bool:
|
||||
return "3" in model
|
||||
|
||||
def _transform_request(
|
||||
self, model: str, input: List[str], inference_params: dict
|
||||
) -> CohereEmbeddingRequestWithModel:
|
||||
is_encoded = False
|
||||
for input_str in input:
|
||||
is_encoded = is_base64_encoded(input_str)
|
||||
|
||||
if is_encoded: # check if string is b64 encoded image or not
|
||||
transformed_request = CohereEmbeddingRequestWithModel(
|
||||
model=model,
|
||||
images=input,
|
||||
input_type="image",
|
||||
)
|
||||
else:
|
||||
transformed_request = CohereEmbeddingRequestWithModel(
|
||||
model=model,
|
||||
texts=input,
|
||||
input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
|
||||
)
|
||||
|
||||
for k, v in inference_params.items():
|
||||
transformed_request[k] = v # type: ignore
|
||||
|
||||
return transformed_request
|
||||
|
||||
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
|
||||
input_tokens = 0
|
||||
|
||||
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
|
||||
|
||||
image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")
|
||||
|
||||
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
|
||||
if image_tokens is None and text_tokens is None:
|
||||
for text in input:
|
||||
input_tokens += len(encoding.encode(text))
|
||||
else:
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
image_tokens=image_tokens,
|
||||
text_tokens=text_tokens,
|
||||
)
|
||||
if image_tokens:
|
||||
input_tokens += image_tokens
|
||||
if text_tokens:
|
||||
input_tokens += text_tokens
|
||||
|
||||
return Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=input_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
)
|
||||
|
||||
def _transform_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
api_key: Optional[str],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
data: Union[dict, CohereEmbeddingRequest],
|
||||
model_response: EmbeddingResponse,
|
||||
model: str,
|
||||
encoding: Any,
|
||||
input: list,
|
||||
) -> EmbeddingResponse:
|
||||
response_json = response.json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response_json,
|
||||
)
|
||||
"""
|
||||
response
|
||||
{
|
||||
'object': "list",
|
||||
'data': [
|
||||
|
||||
]
|
||||
'model',
|
||||
'usage'
|
||||
}
|
||||
"""
|
||||
embeddings = response_json["embeddings"]
|
||||
output_data = []
|
||||
for idx, embedding in enumerate(embeddings):
|
||||
output_data.append(
|
||||
{"object": "embedding", "index": idx, "embedding": embedding}
|
||||
)
|
||||
model_response.object = "list"
|
||||
model_response.data = output_data
|
||||
model_response.model = model
|
||||
input_tokens = 0
|
||||
for text in input:
|
||||
input_tokens += len(encoding.encode(text))
|
||||
|
||||
setattr(
|
||||
model_response,
|
||||
"usage",
|
||||
self._calculate_usage(input, encoding, response_json.get("meta", {})),
|
||||
)
|
||||
|
||||
return model_response
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,151 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.rerank import OptionalRerankParams, RerankRequest
|
||||
from litellm.types.utils import RerankResponse
|
||||
|
||||
from ..common_utils import CohereError
|
||||
|
||||
|
||||
class CohereRerankConfig(BaseRerankConfig):
|
||||
"""
|
||||
Reference: https://docs.cohere.com/v2/reference/rerank
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
if api_base:
|
||||
# Remove trailing slashes and ensure clean base URL
|
||||
api_base = api_base.rstrip("/")
|
||||
if not api_base.endswith("/v1/rerank"):
|
||||
api_base = f"{api_base}/v1/rerank"
|
||||
return api_base
|
||||
return "https://api.cohere.ai/v1/rerank"
|
||||
|
||||
def get_supported_cohere_rerank_params(self, model: str) -> list:
|
||||
return [
|
||||
"query",
|
||||
"documents",
|
||||
"top_n",
|
||||
"max_chunks_per_doc",
|
||||
"rank_fields",
|
||||
"return_documents",
|
||||
]
|
||||
|
||||
def map_cohere_rerank_params(
|
||||
self,
|
||||
non_default_params: Optional[dict],
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
"""
|
||||
Map Cohere rerank params
|
||||
|
||||
No mapping required - returns all supported params
|
||||
"""
|
||||
return OptionalRerankParams(
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_chunks_per_doc=max_chunks_per_doc,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
if api_key is None:
|
||||
api_key = (
|
||||
get_secret_str("COHERE_API_KEY")
|
||||
or get_secret_str("CO_API_KEY")
|
||||
or litellm.cohere_key
|
||||
)
|
||||
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
|
||||
)
|
||||
|
||||
default_headers = {
|
||||
"Authorization": f"bearer {api_key}",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
# If 'Authorization' is provided in headers, it overrides the default.
|
||||
if "Authorization" in headers:
|
||||
default_headers["Authorization"] = headers["Authorization"]
|
||||
|
||||
# Merge other headers, overriding any default ones except Authorization
|
||||
return {**default_headers, **headers}
|
||||
|
||||
def transform_rerank_request(
|
||||
self,
|
||||
model: str,
|
||||
optional_rerank_params: OptionalRerankParams,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
if "query" not in optional_rerank_params:
|
||||
raise ValueError("query is required for Cohere rerank")
|
||||
if "documents" not in optional_rerank_params:
|
||||
raise ValueError("documents is required for Cohere rerank")
|
||||
rerank_request = RerankRequest(
|
||||
model=model,
|
||||
query=optional_rerank_params["query"],
|
||||
documents=optional_rerank_params["documents"],
|
||||
top_n=optional_rerank_params.get("top_n", None),
|
||||
rank_fields=optional_rerank_params.get("rank_fields", None),
|
||||
return_documents=optional_rerank_params.get("return_documents", None),
|
||||
max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
|
||||
)
|
||||
return rerank_request.model_dump(exclude_none=True)
|
||||
|
||||
def transform_rerank_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: RerankResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
request_data: dict = {},
|
||||
optional_params: dict = {},
|
||||
litellm_params: dict = {},
|
||||
) -> RerankResponse:
|
||||
"""
|
||||
Transform Cohere rerank response
|
||||
|
||||
No transformation required, litellm follows cohere API response format
|
||||
"""
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise CohereError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
return RerankResponse(**raw_response_json)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(message=error_message, status_code=status_code)
|
||||
Binary file not shown.
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
|
||||
from litellm.types.rerank import OptionalRerankParams, RerankRequest
|
||||
|
||||
|
||||
class CohereRerankV2Config(CohereRerankConfig):
|
||||
"""
|
||||
Reference: https://docs.cohere.com/v2/reference/rerank
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
if api_base:
|
||||
# Remove trailing slashes and ensure clean base URL
|
||||
api_base = api_base.rstrip("/")
|
||||
if not api_base.endswith("/v2/rerank"):
|
||||
api_base = f"{api_base}/v2/rerank"
|
||||
return api_base
|
||||
return "https://api.cohere.ai/v2/rerank"
|
||||
|
||||
def get_supported_cohere_rerank_params(self, model: str) -> list:
|
||||
return [
|
||||
"query",
|
||||
"documents",
|
||||
"top_n",
|
||||
"max_tokens_per_doc",
|
||||
"rank_fields",
|
||||
"return_documents",
|
||||
]
|
||||
|
||||
def map_cohere_rerank_params(
|
||||
self,
|
||||
non_default_params: Optional[dict],
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
query: str,
|
||||
documents: List[Union[str, Dict[str, Any]]],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
top_n: Optional[int] = None,
|
||||
rank_fields: Optional[List[str]] = None,
|
||||
return_documents: Optional[bool] = True,
|
||||
max_chunks_per_doc: Optional[int] = None,
|
||||
max_tokens_per_doc: Optional[int] = None,
|
||||
) -> OptionalRerankParams:
|
||||
"""
|
||||
Map Cohere rerank params
|
||||
|
||||
No mapping required - returns all supported params
|
||||
"""
|
||||
return OptionalRerankParams(
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
rank_fields=rank_fields,
|
||||
return_documents=return_documents,
|
||||
max_tokens_per_doc=max_tokens_per_doc,
|
||||
)
|
||||
|
||||
def transform_rerank_request(
|
||||
self,
|
||||
model: str,
|
||||
optional_rerank_params: OptionalRerankParams,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
if "query" not in optional_rerank_params:
|
||||
raise ValueError("query is required for Cohere rerank")
|
||||
if "documents" not in optional_rerank_params:
|
||||
raise ValueError("documents is required for Cohere rerank")
|
||||
rerank_request = RerankRequest(
|
||||
model=model,
|
||||
query=optional_rerank_params["query"],
|
||||
documents=optional_rerank_params["documents"],
|
||||
top_n=optional_rerank_params.get("top_n", None),
|
||||
rank_fields=optional_rerank_params.get("rank_fields", None),
|
||||
return_documents=optional_rerank_params.get("return_documents", None),
|
||||
max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
|
||||
)
|
||||
return rerank_request.model_dump(exclude_none=True)
|
||||
Reference in New Issue
Block a user