structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,373 @@
import json
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class CohereError(BaseLLMException):
def __init__(
self,
status_code: int,
message: str,
headers: Optional[httpx.Headers] = None,
):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
status_code=status_code,
message=message,
headers=headers,
)
class CohereChatConfig(BaseConfig):
"""
Configuration class for Cohere's API interface.
Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
generation_id (str, optional): Unique identifier for the generated reply.
response_id (str, optional): Unique identifier for the response.
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
max_tokens [DEPRECATED - use max_completion_tokens] (int, optional): The maximum number of tokens the model will generate as part of the response.
max_completion_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
chat_history: Optional[list] = None
generation_id: Optional[str] = None
response_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
temperature: Optional[int] = None
max_tokens: Optional[int] = None
max_completion_tokens: Optional[int] = None
k: Optional[int] = None
p: Optional[int] = None
frequency_penalty: Optional[int] = None
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
preamble: Optional[str] = None,
chat_history: Optional[list] = None,
generation_id: Optional[str] = None,
response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt_truncation: Optional[str] = None,
connectors: Optional[list] = None,
search_queries_only: Optional[bool] = None,
documents: Optional[list] = None,
temperature: Optional[int] = None,
max_tokens: Optional[int] = None,
max_completion_tokens: Optional[int] = None,
k: Optional[int] = None,
p: Optional[int] = None,
frequency_penalty: Optional[int] = None,
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["num_generations"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "tools":
optional_params["tools"] = value
if param == "seed":
optional_params["seed"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
## ADD CITATIONS
if "citations" in raw_response_json:
setattr(model_response, "citations", raw_response_json["citations"])
## Tool calling response
cohere_tools_response = raw_response_json.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:
function_name = tool.get("name", "")
generation_id = tool.get("generation_id", "")
parameters = tool.get("parameters", {})
tool_call = {
"id": f"call_{generation_id}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(parameters),
},
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - use cohere `billed_units` for returning usage
billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
prompt_tokens = billed_units.get("input_tokens", 0)
completion_tokens = billed_units.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools
def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""
# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}
for param_name, param_def in openai_tool["function"]["parameters"][
"properties"
].items():
required_params = (
openai_tool.get("function", {})
.get("parameters", {})
.get("required", [])
)
cohere_param_def = {
"description": param_def.get("description", ""),
"type": param_def.get("type", ""),
"required": param_name in required_params,
}
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
return cohere_tool
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)

View File

@@ -0,0 +1,356 @@
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.cohere import CohereV2ChatResponse
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
from litellm.types.utils import ModelResponse, Usage
from ..common_utils import CohereError
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class CohereV2ChatConfig(BaseConfig):
"""
Configuration class for Cohere's API interface.
Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
generation_id (str, optional): Unique identifier for the generated reply.
response_id (str, optional): Unique identifier for the response.
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""
preamble: Optional[str] = None
chat_history: Optional[list] = None
generation_id: Optional[str] = None
response_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
temperature: Optional[int] = None
max_tokens: Optional[int] = None
k: Optional[int] = None
p: Optional[int] = None
frequency_penalty: Optional[int] = None
presence_penalty: Optional[int] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None
def __init__(
self,
preamble: Optional[str] = None,
chat_history: Optional[list] = None,
generation_id: Optional[str] = None,
response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt_truncation: Optional[str] = None,
connectors: Optional[list] = None,
search_queries_only: Optional[bool] = None,
documents: Optional[list] = None,
temperature: Optional[int] = None,
max_tokens: Optional[int] = None,
k: Optional[int] = None,
p: Optional[int] = None,
frequency_penalty: Optional[int] = None,
presence_penalty: Optional[int] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["num_generations"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "tools":
optional_params["tools"] = value
if param == "seed":
optional_params["seed"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message
## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True
return optional_params
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
try:
raw_response_json = raw_response.json()
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
try:
cohere_v2_chat_response = CohereV2ChatResponse(**raw_response_json) # type: ignore
except Exception:
raise CohereError(message=raw_response.text, status_code=422)
cohere_content = cohere_v2_chat_response["message"].get("content", None)
if cohere_content is not None:
model_response.choices[0].message.content = "".join( # type: ignore
[
content.get("text", "")
for content in cohere_content
if content is not None
]
)
## ADD CITATIONS
if "citations" in cohere_v2_chat_response:
setattr(model_response, "citations", cohere_v2_chat_response["citations"])
## Tool calling response
cohere_tools_response = cohere_v2_chat_response["message"].get("tool_calls", [])
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls: List[ChatCompletionToolCallChunk] = []
for index, tool in enumerate(cohere_tools_response):
tool_call: ChatCompletionToolCallChunk = {
**tool, # type: ignore
"index": index,
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore
## CALCULATING USAGE - use cohere `billed_units` for returning usage
token_usage = cohere_v2_chat_response["usage"].get("tokens", {})
prompt_tokens = token_usage.get("input_tokens", 0)
completion_tokens = token_usage.get("output_tokens", 0)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools
def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""
# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}
for param_name, param_def in openai_tool["function"]["parameters"][
"properties"
].items():
required_params = (
openai_tool.get("function", {})
.get("parameters", {})
.get("required", [])
)
cohere_param_def = {
"description": param_def.get("description", ""),
"type": param_def.get("type", ""),
"required": param_name in required_params,
}
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
return cohere_tool
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)

View File

@@ -0,0 +1,147 @@
import json
from typing import List, Optional
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
)
class CohereError(BaseLLMException):
def __init__(self, status_code, message):
super().__init__(status_code=status_code, message=message)
def validate_environment(
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
"""
Return headers to use for cohere chat completion request
Cohere API Ref: https://docs.cohere.com/reference/chat
Expected headers:
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
"Authorization": "bearer $CO_API_KEY"
}
"""
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"bearer {api_key}"
return headers
class ModelResponseIterator:
def __init__(
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
):
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List = []
self.tool_index = -1
self.json_mode = json_mode
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
provider_specific_fields = None
index = int(chunk.get("index", 0))
if "text" in chunk:
text = chunk["text"]
elif "is_finished" in chunk and chunk["is_finished"] is True:
is_finished = chunk["is_finished"]
finish_reason = chunk["finish_reason"]
if "citations" in chunk:
provider_specific_fields = {"citations": chunk["citations"]}
returned_chunk = GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=index,
provider_specific_fields=provider_specific_fields,
)
return returned_chunk
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
"""
Convert a string chunk to a GenericStreamingChunk
Note: This is used for Cohere pass through streaming logging
"""
str_line = chunk
if isinstance(chunk, bytes): # Handle binary data
str_line = chunk.decode("utf-8") # Convert bytes to string
index = str_line.find("data:")
if index != -1:
str_line = str_line[index:]
data_json = json.loads(str_line)
return self.chunk_parser(chunk=data_json)
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

View File

@@ -0,0 +1,5 @@
"""
Cohere /generate API - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View File

@@ -0,0 +1,265 @@
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from ..common_utils import CohereError
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class CohereTextConfig(BaseConfig):
"""
Reference: https://docs.cohere.com/reference/generate
The class `CohereConfig` provides configuration for the Cohere's API interface. Below are the parameters:
- `num_generations` (integer): Maximum number of generations returned. Default is 1, with a minimum value of 1 and a maximum value of 5.
- `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default value is 20.
- `truncate` (string): Specifies how the API handles inputs longer than maximum token length. Options include NONE, START, END. Default is END.
- `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.75.
- `preset` (string): Identifier of a custom preset, a combination of parameters such as prompt, temperature etc.
- `end_sequences` (array of strings): The generated text gets cut at the beginning of the earliest occurrence of an end sequence, which will be excluded from the text.
- `stop_sequences` (array of strings): The generated text gets cut at the end of the earliest occurrence of a stop sequence, which will be included in the text.
- `k` (integer): Limits generation at each step to top `k` most likely tokens. Default is 0.
- `p` (number): Limits generation at each step to most likely tokens with total probability mass of `p`. Default is 0.
- `frequency_penalty` (number): Reduces repetitiveness of generated tokens. Higher values apply stronger penalties to previously occurred tokens.
- `presence_penalty` (number): Reduces repetitiveness of generated tokens. Similar to frequency_penalty, but this penalty applies equally to all tokens that have already appeared.
- `return_likelihoods` (string): Specifies how and if token likelihoods are returned with the response. Options include GENERATION, ALL and NONE.
- `logit_bias` (object): Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. e.g. {"hello_world": 1233}
"""
num_generations: Optional[int] = None
max_tokens: Optional[int] = None
truncate: Optional[str] = None
temperature: Optional[int] = None
preset: Optional[str] = None
end_sequences: Optional[list] = None
stop_sequences: Optional[list] = None
k: Optional[int] = None
p: Optional[int] = None
frequency_penalty: Optional[int] = None
presence_penalty: Optional[int] = None
return_likelihoods: Optional[str] = None
logit_bias: Optional[dict] = None
def __init__(
self,
num_generations: Optional[int] = None,
max_tokens: Optional[int] = None,
truncate: Optional[str] = None,
temperature: Optional[int] = None,
preset: Optional[str] = None,
end_sequences: Optional[list] = None,
stop_sequences: Optional[list] = None,
k: Optional[int] = None,
p: Optional[int] = None,
frequency_penalty: Optional[int] = None,
presence_penalty: Optional[int] = None,
return_likelihoods: Optional[str] = None,
logit_bias: Optional[dict] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return super().get_config()
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)
def get_supported_openai_params(self, model: str) -> List:
return [
"stream",
"temperature",
"max_tokens",
"logit_bias",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"extra_headers",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
elif param == "temperature":
optional_params["temperature"] = value
elif param == "max_tokens":
optional_params["max_tokens"] = value
elif param == "n":
optional_params["num_generations"] = value
elif param == "logit_bias":
optional_params["logit_bias"] = value
elif param == "top_p":
optional_params["p"] = value
elif param == "frequency_penalty":
optional_params["frequency_penalty"] = value
elif param == "presence_penalty":
optional_params["presence_penalty"] = value
elif param == "stop":
optional_params["stop_sequences"] = value
return optional_params
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
prompt = " ".join(
convert_content_list_to_str(message=message) for message in messages
)
## Load Config
config = litellm.CohereConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api(
tools=optional_params["tools"]
)
optional_params["tools"] = tool_calling_system_prompt
data = {
"model": model,
"prompt": prompt,
**optional_params,
}
return data
def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
prompt = " ".join(
convert_content_list_to_str(message=message) for message in messages
)
completion_response = raw_response.json()
choices_list = []
for idx, item in enumerate(completion_response["generations"]):
if len(item["text"]) > 0:
message_obj = Message(content=item["text"])
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason=item["finish_reason"],
index=idx + 1,
message=message_obj,
)
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
## CALCULATING USAGE
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def _construct_cohere_tool_for_completion_api(
self,
tools: Optional[List] = None,
) -> dict:
if tools is None:
tools = []
return {"tools": tools}
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)

View File

@@ -0,0 +1,177 @@
import json
from typing import Any, Callable, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
get_async_httpx_client,
)
from litellm.types.llms.bedrock import CohereEmbeddingRequest
from litellm.types.utils import EmbeddingResponse
from .transformation import CohereEmbeddingConfig
def validate_environment(api_key, headers: dict):
headers.update(
{
"Request-Source": "unspecified:litellm",
"accept": "application/json",
"content-type": "application/json",
}
)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
class CohereError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://api.cohere.ai/v1/generate"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
async def async_embedding(
model: str,
data: Union[dict, CohereEmbeddingRequest],
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Optional[Union[float, httpx.Timeout]],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.COHERE,
params={"timeout": timeout},
)
try:
response = await client.post(api_base, headers=headers, data=json.dumps(data))
except httpx.HTTPStatusError as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=e.response.text,
)
raise e
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
## PROCESS RESPONSE ##
return CohereEmbeddingConfig()._transform_response(
response=response,
api_key=api_key,
logging_obj=logging_obj,
data=data,
model_response=model_response,
model=model,
encoding=encoding,
input=input,
)
def embedding(
model: str,
input: list,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
headers: dict,
encoding: Any,
data: Optional[Union[dict, CohereEmbeddingRequest]] = None,
complete_api_base: Optional[str] = None,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
headers = validate_environment(api_key, headers=headers)
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
model = model
data = data or CohereEmbeddingConfig()._transform_request(
model=model, input=input, inference_params=optional_params
)
## ROUTING
if aembedding is True:
return async_embedding(
model=model,
data=data,
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
optional_params=optional_params,
api_base=embed_url,
api_key=api_key,
headers=headers,
encoding=encoding,
client=(
client
if client is not None and isinstance(client, AsyncHTTPHandler)
else None
),
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
return CohereEmbeddingConfig()._transform_response(
response=response,
api_key=api_key,
logging_obj=logging_obj,
data=data,
model_response=model_response,
model=model,
encoding=encoding,
input=input,
)

View File

@@ -0,0 +1,151 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Cohere's /v1/embed format.
Why separate file? Make it easy to see how transformation works
Convers
- v3 embedding models
- v2 embedding models
Docs - https://docs.cohere.com/v2/reference/embed
"""
from typing import Any, List, Optional, Union
import httpx
from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.bedrock import (
CohereEmbeddingRequest,
CohereEmbeddingRequestWithModel,
)
from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage
from litellm.utils import is_base64_encoded
class CohereEmbeddingConfig:
"""
Reference: https://docs.cohere.com/v2/reference/embed
"""
def __init__(self) -> None:
pass
def get_supported_openai_params(self) -> List[str]:
return ["encoding_format"]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for k, v in non_default_params.items():
if k == "encoding_format":
optional_params["embedding_types"] = v
return optional_params
def _is_v3_model(self, model: str) -> bool:
return "3" in model
def _transform_request(
self, model: str, input: List[str], inference_params: dict
) -> CohereEmbeddingRequestWithModel:
is_encoded = False
for input_str in input:
is_encoded = is_base64_encoded(input_str)
if is_encoded: # check if string is b64 encoded image or not
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
images=input,
input_type="image",
)
else:
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
texts=input,
input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
)
for k, v in inference_params.items():
transformed_request[k] = v # type: ignore
return transformed_request
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
input_tokens = 0
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
if image_tokens is None and text_tokens is None:
for text in input:
input_tokens += len(encoding.encode(text))
else:
prompt_tokens_details = PromptTokensDetailsWrapper(
image_tokens=image_tokens,
text_tokens=text_tokens,
)
if image_tokens:
input_tokens += image_tokens
if text_tokens:
input_tokens += text_tokens
return Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
prompt_tokens_details=prompt_tokens_details,
)
def _transform_response(
self,
response: httpx.Response,
api_key: Optional[str],
logging_obj: LiteLLMLoggingObj,
data: Union[dict, CohereEmbeddingRequest],
model_response: EmbeddingResponse,
model: str,
encoding: Any,
input: list,
) -> EmbeddingResponse:
response_json = response.json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response_json,
)
"""
response
{
'object': "list",
'data': [
]
'model',
'usage'
}
"""
embeddings = response_json["embeddings"]
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
self._calculate_usage(input, encoding, response_json.get("meta", {})),
)
return model_response

View File

@@ -0,0 +1,5 @@
"""
Cohere Rerank - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View File

@@ -0,0 +1,151 @@
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankRequest
from litellm.types.utils import RerankResponse
from ..common_utils import CohereError
class CohereRerankConfig(BaseRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank
"""
def __init__(self) -> None:
pass
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base:
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/v1/rerank"):
api_base = f"{api_base}/v1/rerank"
return api_base
return "https://api.cohere.ai/v1/rerank"
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"documents",
"top_n",
"max_chunks_per_doc",
"rank_fields",
"return_documents",
]
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> OptionalRerankParams:
"""
Map Cohere rerank params
No mapping required - returns all supported params
"""
return OptionalRerankParams(
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_chunks_per_doc=max_chunks_per_doc,
)
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
if api_key is None:
api_key = (
get_secret_str("COHERE_API_KEY")
or get_secret_str("CO_API_KEY")
or litellm.cohere_key
)
if api_key is None:
raise ValueError(
"Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
)
default_headers = {
"Authorization": f"bearer {api_key}",
"accept": "application/json",
"content-type": "application/json",
}
# If 'Authorization' is provided in headers, it overrides the default.
if "Authorization" in headers:
default_headers["Authorization"] = headers["Authorization"]
# Merge other headers, overriding any default ones except Authorization
return {**default_headers, **headers}
def transform_rerank_request(
self,
model: str,
optional_rerank_params: OptionalRerankParams,
headers: dict,
) -> dict:
if "query" not in optional_rerank_params:
raise ValueError("query is required for Cohere rerank")
if "documents" not in optional_rerank_params:
raise ValueError("documents is required for Cohere rerank")
rerank_request = RerankRequest(
model=model,
query=optional_rerank_params["query"],
documents=optional_rerank_params["documents"],
top_n=optional_rerank_params.get("top_n", None),
rank_fields=optional_rerank_params.get("rank_fields", None),
return_documents=optional_rerank_params.get("return_documents", None),
max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
)
return rerank_request.model_dump(exclude_none=True)
def transform_rerank_response(
self,
model: str,
raw_response: httpx.Response,
model_response: RerankResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
) -> RerankResponse:
"""
Transform Cohere rerank response
No transformation required, litellm follows cohere API response format
"""
try:
raw_response_json = raw_response.json()
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)
return RerankResponse(**raw_response_json)
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(message=error_message, status_code=status_code)

View File

@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional, Union
from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest
class CohereRerankV2Config(CohereRerankConfig):
"""
Reference: https://docs.cohere.com/v2/reference/rerank
"""
def __init__(self) -> None:
pass
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
if api_base:
# Remove trailing slashes and ensure clean base URL
api_base = api_base.rstrip("/")
if not api_base.endswith("/v2/rerank"):
api_base = f"{api_base}/v2/rerank"
return api_base
return "https://api.cohere.ai/v2/rerank"
def get_supported_cohere_rerank_params(self, model: str) -> list:
return [
"query",
"documents",
"top_n",
"max_tokens_per_doc",
"rank_fields",
"return_documents",
]
def map_cohere_rerank_params(
self,
non_default_params: Optional[dict],
model: str,
drop_params: bool,
query: str,
documents: List[Union[str, Dict[str, Any]]],
custom_llm_provider: Optional[str] = None,
top_n: Optional[int] = None,
rank_fields: Optional[List[str]] = None,
return_documents: Optional[bool] = True,
max_chunks_per_doc: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
) -> OptionalRerankParams:
"""
Map Cohere rerank params
No mapping required - returns all supported params
"""
return OptionalRerankParams(
query=query,
documents=documents,
top_n=top_n,
rank_fields=rank_fields,
return_documents=return_documents,
max_tokens_per_doc=max_tokens_per_doc,
)
def transform_rerank_request(
self,
model: str,
optional_rerank_params: OptionalRerankParams,
headers: dict,
) -> dict:
if "query" not in optional_rerank_params:
raise ValueError("query is required for Cohere rerank")
if "documents" not in optional_rerank_params:
raise ValueError("documents is required for Cohere rerank")
rerank_request = RerankRequest(
model=model,
query=optional_rerank_params["query"],
documents=optional_rerank_params["documents"],
top_n=optional_rerank_params.get("top_n", None),
rank_fields=optional_rerank_params.get("rank_fields", None),
return_documents=optional_rerank_params.get("return_documents", None),
max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
)
return rerank_request.model_dump(exclude_none=True)