structure saas with tools

Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

File diff suppressed because it is too large


@@ -0,0 +1,167 @@
import json
import logging
import os
import sys
from datetime import datetime
from logging import Formatter
set_verbose = False
if set_verbose is True:
logging.warning(
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
)
json_logs = bool(os.getenv("JSON_LOGS", False))  # note: any non-empty value, even "false", enables JSON logs
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "DEBUG")
numeric_level: int = getattr(logging, log_level.upper())
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
class JsonFormatter(Formatter):
def __init__(self):
super(JsonFormatter, self).__init__()
def formatTime(self, record, datefmt=None):
# Use datetime to format the timestamp in ISO 8601 format
dt = datetime.fromtimestamp(record.created)
return dt.isoformat()
def format(self, record):
json_record = {
"message": record.getMessage(),
"level": record.levelname,
"timestamp": self.formatTime(record),
}
if record.exc_info:
json_record["stacktrace"] = self.formatException(record.exc_info)
return json.dumps(json_record)
# Function to set up exception handlers for JSON logging
def _setup_json_exception_handlers(formatter):
# Create a handler with JSON formatting for exceptions
error_handler = logging.StreamHandler()
error_handler.setFormatter(formatter)
# Setup excepthook for uncaught exceptions
def json_excepthook(exc_type, exc_value, exc_traceback):
record = logging.LogRecord(
name="LiteLLM",
level=logging.ERROR,
pathname="",
lineno=0,
msg=str(exc_value),
args=(),
exc_info=(exc_type, exc_value, exc_traceback),
)
error_handler.handle(record)
sys.excepthook = json_excepthook
# Configure asyncio exception handler if possible
try:
import asyncio
def async_json_exception_handler(loop, context):
exception = context.get("exception")
if exception:
record = logging.LogRecord(
name="LiteLLM",
level=logging.ERROR,
pathname="",
lineno=0,
msg=str(exception),
args=(),
exc_info=None,
)
error_handler.handle(record)
else:
loop.default_exception_handler(context)
asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
except Exception:
pass
# Create a formatter and set it for the handler
if json_logs:
handler.setFormatter(JsonFormatter())
_setup_json_exception_handlers(JsonFormatter())
else:
formatter = logging.Formatter(
"\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
datefmt="%H:%M:%S",
)
handler.setFormatter(formatter)
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
verbose_router_logger = logging.getLogger("LiteLLM Router")
verbose_logger = logging.getLogger("LiteLLM")
# Add the handler to the logger
verbose_router_logger.addHandler(handler)
verbose_proxy_logger.addHandler(handler)
verbose_logger.addHandler(handler)
def _turn_on_json():
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
# Define all loggers to update, including root logger
loggers = [logging.getLogger()] + [
verbose_router_logger,
verbose_proxy_logger,
verbose_logger,
]
# Iterate through each logger and update its handlers
for logger in loggers:
# Remove all existing handlers
for h in logger.handlers[:]:
logger.removeHandler(h)
# Add the new handler
logger.addHandler(handler)
# Set up exception handlers
_setup_json_exception_handlers(JsonFormatter())
def _turn_on_debug():
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
def _disable_debugging():
verbose_logger.disabled = True
verbose_router_logger.disabled = True
verbose_proxy_logger.disabled = True
def _enable_debugging():
verbose_logger.disabled = False
verbose_router_logger.disabled = False
verbose_proxy_logger.disabled = False
def print_verbose(print_statement):
try:
if set_verbose:
print(print_statement) # noqa
except Exception:
pass
def _is_debugging_on() -> bool:
"""
Returns True if debugging is on
"""
if verbose_logger.isEnabledFor(logging.DEBUG) or set_verbose is True:
return True
return False
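A minimal usage sketch for the logging module above, assuming it ships as `litellm._logging` (the diff view does not show file names); both env vars are read at import time:

```python
import os

os.environ["JSON_LOGS"] = "1"        # switch the shared handler to JsonFormatter
os.environ["LITELLM_LOG"] = "DEBUG"  # handler level; DEBUG is also the default

from litellm._logging import _turn_on_debug, verbose_logger

_turn_on_debug()  # set the LiteLLM / Router / Proxy loggers to DEBUG
verbose_logger.debug("hello from litellm")  # emitted as a one-line JSON record
```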


@@ -0,0 +1,333 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import inspect
import json
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
from typing import List, Optional, Union
import redis # type: ignore
import redis.asyncio as async_redis # type: ignore
from litellm import get_secret, get_secret_str
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
from ._logging import verbose_logger
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
# Only allow primitive arguments
exclude_args = {
"self",
"connection_pool",
"retry",
}
include_args = ["url"]
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_url_kwargs(client=None):
if client is None:
client = redis.Redis.from_url
arg_spec = inspect.getfullargspec(client)
# Only allow primitive arguments
exclude_args = {
"self",
"connection_pool",
"retry",
}
include_args = ["url"]
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
return available_args
def _get_redis_cluster_kwargs(client=None):
# `client` is accepted for signature parity but unused; cluster kwargs
# always come from redis.RedisCluster
arg_spec = inspect.getfullargspec(redis.RedisCluster)
# Only allow primitive arguments
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
available_args = [x for x in arg_spec.args if x not in exclude_args]
available_args.append("password")
available_args.append("username")
available_args.append("ssl")
return available_args
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
def _redis_kwargs_from_environment():
mapping = _get_redis_env_kwarg_mapping()
return_dict = {}
for k, v in mapping.items():
value = get_secret(k, default_value=None) # type: ignore
if value is not None:
return_dict[v] = value
return return_dict
def get_redis_url_from_environment():
if "REDIS_URL" in os.environ:
return os.environ["REDIS_URL"]
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
raise ValueError(
"Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
)
if "REDIS_PASSWORD" in os.environ:
redis_password = f":{os.environ['REDIS_PASSWORD']}@"
else:
redis_password = ""
return (
f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
)
def _get_redis_client_logic(**env_overrides):
"""
Common functionality across sync + async redis client implementations
"""
### check if "os.environ/<key-name>" passed in
for k, v in env_overrides.items():
if isinstance(v, str) and v.startswith("os.environ/"):
v = v.replace("os.environ/", "")
value = get_secret(v) # type: ignore
env_overrides[k] = value
redis_kwargs = {
**_redis_kwargs_from_environment(),
**env_overrides,
}
_startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
"REDIS_CLUSTER_NODES"
)
if _startup_nodes is not None and isinstance(_startup_nodes, str):
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
_sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
"REDIS_SENTINEL_NODES"
)
if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
_sentinel_password: Optional[str] = redis_kwargs.get(
"sentinel_password", None
) or get_secret_str("REDIS_SENTINEL_PASSWORD")
if _sentinel_password is not None:
redis_kwargs["sentinel_password"] = _sentinel_password
_service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
"REDIS_SERVICE_NAME"
)
if _service_name is not None:
redis_kwargs["service_name"] = _service_name
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
redis_kwargs.pop("host", None)
redis_kwargs.pop("port", None)
redis_kwargs.pop("db", None)
redis_kwargs.pop("password", None)
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
pass
elif (
"sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
):
pass
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.")
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
_redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
if _redis_cluster_nodes_in_env is not None:
try:
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
except json.JSONDecodeError:
raise ValueError(
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
)
verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
sentinel_password = redis_kwargs.get("sentinel_password")
service_name = redis_kwargs.get("service_name")
if not sentinel_nodes or not service_name:
raise ValueError(
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
)
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
# Set up the Sentinel client
sentinel = redis.Sentinel(
sentinel_nodes,
socket_timeout=REDIS_SOCKET_TIMEOUT,
password=sentinel_password,
)
# Return the master instance for the given service
return sentinel.master_for(service_name)
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
sentinel_password = redis_kwargs.get("sentinel_password")
service_name = redis_kwargs.get("service_name")
if not sentinel_nodes or not service_name:
raise ValueError(
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
)
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
# Set up the Sentinel client
sentinel = async_redis.Sentinel(
sentinel_nodes,
socket_timeout=REDIS_SOCKET_TIMEOUT,
password=sentinel_password,
)
# Return the master instance for the given service
return sentinel.master_for(service_name)
def get_redis_client(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
args = _get_redis_url_kwargs()
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
return init_redis_cluster(redis_kwargs)
# Check for Redis Sentinel
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
return _init_redis_sentinel(redis_kwargs)
return redis.Redis(**redis_kwargs)
def get_redis_async_client(
**env_overrides,
) -> async_redis.Redis:
redis_kwargs = _get_redis_client_logic(**env_overrides)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
url_kwargs = {}
for arg in redis_kwargs:
if arg in args:
url_kwargs[arg] = redis_kwargs[arg]
else:
verbose_logger.debug(
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
arg
)
)
return async_redis.Redis.from_url(**url_kwargs)
if "startup_nodes" in redis_kwargs:
from redis.cluster import ClusterNode
args = _get_redis_cluster_kwargs()
cluster_kwargs = {}
for arg in redis_kwargs:
if arg in args:
cluster_kwargs[arg] = redis_kwargs[arg]
new_startup_nodes: List[ClusterNode] = []
for item in redis_kwargs["startup_nodes"]:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return async_redis.RedisCluster(
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
)
# Check for Redis Sentinel
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
return _init_async_redis_sentinel(redis_kwargs)
return async_redis.Redis(
**redis_kwargs,
)
def get_redis_connection_pool(**env_overrides):
redis_kwargs = _get_redis_client_logic(**env_overrides)
verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
return async_redis.BlockingConnectionPool.from_url(
timeout=REDIS_CONNECTION_POOL_TIMEOUT, url=redis_kwargs["url"]
)
connection_class = async_redis.Connection
if "ssl" in redis_kwargs:
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
redis_kwargs.pop("startup_nodes", None)
return async_redis.BlockingConnectionPool(
timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
)
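A hedged usage sketch for the Redis helpers above; the `REDIS_*` names come from `_get_redis_env_kwarg_mapping()`, and the `litellm._redis` import path is an assumption:

```python
import os

os.environ["REDIS_HOST"] = "localhost"
os.environ["REDIS_PORT"] = "6379"
os.environ["REDIS_PASSWORD"] = "my-password"  # optional

from litellm._redis import get_redis_client, get_redis_url_from_environment

# Builds "redis://:my-password@localhost:6379" from the env vars above.
print(get_redis_url_from_environment())

# Env-derived kwargs are merged with explicit overrides before redis.Redis(...).
client = get_redis_client(db=0)
client.ping()
```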


@@ -0,0 +1,311 @@
import asyncio
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.proxy._types import UserAPIKeyAuth
from .integrations.custom_logger import CustomLogger
from .integrations.datadog.datadog import DataDogLogger
from .integrations.opentelemetry import OpenTelemetry
from .integrations.prometheus_services import PrometheusServicesLogger
from .types.services import ServiceLoggerPayload, ServiceTypes
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
OTELClass = OpenTelemetry
else:
Span = Any
OTELClass = Any
class ServiceLogging(CustomLogger):
"""
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
"""
def __init__(self, mock_testing: bool = False) -> None:
self.mock_testing = mock_testing
self.mock_testing_sync_success_hook = 0
self.mock_testing_async_success_hook = 0
self.mock_testing_sync_failure_hook = 0
self.mock_testing_async_failure_hook = 0
if "prometheus_system" in litellm.service_callback:
self.prometheusServicesLogger = PrometheusServicesLogger()
def service_success_hook(
self,
service: ServiceTypes,
duration: float,
call_type: str,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[float, datetime]] = None,
):
"""
Handles both sync and async monitoring by checking for existing event loop.
"""
if self.mock_testing:
self.mock_testing_sync_success_hook += 1
try:
# Try to get the current event loop
loop = asyncio.get_event_loop()
# Check if the loop is running
if loop.is_running():
# If we're in a running loop, create a task
loop.create_task(
self.async_service_success_hook(
service=service,
duration=duration,
call_type=call_type,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
else:
# Loop exists but not running, we can use run_until_complete
loop.run_until_complete(
self.async_service_success_hook(
service=service,
duration=duration,
call_type=call_type,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
except RuntimeError:
# No event loop exists, create a new one and run
asyncio.run(
self.async_service_success_hook(
service=service,
duration=duration,
call_type=call_type,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
)
)
def service_failure_hook(
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
):
"""
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
"""
if self.mock_testing:
self.mock_testing_sync_failure_hook += 1
async def async_service_success_hook(
self,
service: ServiceTypes,
call_type: str,
duration: float,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[datetime, float]] = None,
event_metadata: Optional[dict] = None,
):
"""
- For counting if the redis, postgres call is successful
"""
if self.mock_testing:
self.mock_testing_async_success_hook += 1
payload = ServiceLoggerPayload(
is_error=False,
error=None,
service=service,
duration=duration,
call_type=call_type,
event_metadata=event_metadata,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
await self.prometheusServicesLogger.async_service_success_hook(
payload=payload
)
elif callback == "datadog" or isinstance(callback, DataDogLogger):
await self.init_datadog_logger_if_none()
await self.dd_logger.async_service_success_hook(
payload=payload,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
elif callback == "otel" or isinstance(callback, OpenTelemetry):
from litellm.proxy.proxy_server import open_telemetry_logger
await self.init_otel_logger_if_none()
if (
parent_otel_span is not None
and open_telemetry_logger is not None
and isinstance(open_telemetry_logger, OpenTelemetry)
):
await self.otel_logger.async_service_success_hook(
payload=payload,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
async def init_prometheus_services_logger_if_none(self):
"""
initializes prometheusServicesLogger if it is None or no attribute exists on ServiceLogging Object
"""
if not hasattr(self, "prometheusServicesLogger"):
self.prometheusServicesLogger = PrometheusServicesLogger()
elif self.prometheusServicesLogger is None:
self.prometheusServicesLogger = PrometheusServicesLogger()
return
async def init_datadog_logger_if_none(self):
"""
initializes dd_logger if it is None or no attribute exists on ServiceLogging Object
"""
from litellm.integrations.datadog.datadog import DataDogLogger
if not hasattr(self, "dd_logger"):
self.dd_logger: DataDogLogger = DataDogLogger()
return
async def init_otel_logger_if_none(self):
"""
initializes otel_logger if it is None or no attribute exists on ServiceLogging Object
"""
from litellm.proxy.proxy_server import open_telemetry_logger
if not hasattr(self, "otel_logger"):
if open_telemetry_logger is not None and isinstance(
open_telemetry_logger, OpenTelemetry
):
self.otel_logger: OpenTelemetry = open_telemetry_logger
else:
verbose_logger.warning(
"ServiceLogger: open_telemetry_logger is None or not an instance of OpenTelemetry"
)
return
async def async_service_failure_hook(
self,
service: ServiceTypes,
duration: float,
error: Union[str, Exception],
call_type: str,
parent_otel_span: Optional[Span] = None,
start_time: Optional[Union[datetime, float]] = None,
end_time: Optional[Union[float, datetime]] = None,
event_metadata: Optional[dict] = None,
):
"""
- For counting if the redis, postgres call is unsuccessful
"""
if self.mock_testing:
self.mock_testing_async_failure_hook += 1
error_message = ""
if isinstance(error, Exception):
error_message = str(error)
elif isinstance(error, str):
error_message = error
payload = ServiceLoggerPayload(
is_error=True,
error=error_message,
service=service,
duration=duration,
call_type=call_type,
event_metadata=event_metadata,
)
for callback in litellm.service_callback:
if callback == "prometheus_system":
await self.init_prometheus_services_logger_if_none()
await self.prometheusServicesLogger.async_service_failure_hook(
payload=payload,
error=error,
)
elif callback == "datadog" or isinstance(callback, DataDogLogger):
await self.init_datadog_logger_if_none()
await self.dd_logger.async_service_failure_hook(
payload=payload,
error=error_message,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
elif callback == "otel" or isinstance(callback, OpenTelemetry):
from litellm.proxy.proxy_server import open_telemetry_logger
await self.init_otel_logger_if_none()
if not isinstance(error, str):
error = str(error)
if (
parent_otel_span is not None
and open_telemetry_logger is not None
and isinstance(open_telemetry_logger, OpenTelemetry)
):
await self.otel_logger.async_service_failure_hook(
payload=payload,
error=error,
parent_otel_span=parent_otel_span,
start_time=start_time,
end_time=end_time,
event_metadata=event_metadata,
)
async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
):
"""
Hook to track failed litellm-service calls
"""
return await super().async_post_call_failure_hook(
request_data,
original_exception,
user_api_key_dict,
)
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Hook to track latency for litellm proxy llm api calls
"""
try:
_duration = end_time - start_time
if isinstance(_duration, timedelta):
_duration = _duration.total_seconds()
elif isinstance(_duration, float):
pass
else:
raise Exception(
"Duration={} is not a float or timedelta object. type={}".format(
_duration, type(_duration)
)
) # invalid _duration value
await self.async_service_success_hook(
service=ServiceTypes.LITELLM,
duration=_duration,
call_type=kwargs["call_type"],
)
except Exception as e:
raise e
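A rough sketch of how `ServiceLogging` is exercised; the `litellm._service_logger` path and the `ServiceTypes.REDIS` member are assumptions based on the imports above:

```python
import asyncio
import time

import litellm
from litellm._service_logger import ServiceLogging
from litellm.types.services import ServiceTypes

litellm.service_callback = ["prometheus_system"]  # route service metrics to Prometheus

async def timed_redis_call() -> None:
    service_logger = ServiceLogging()
    start = time.time()
    # ... the actual redis call would happen here ...
    await service_logger.async_service_success_hook(
        service=ServiceTypes.REDIS,
        call_type="get_cache",
        duration=time.time() - start,
    )

asyncio.run(timed_redis_call())
```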


@@ -0,0 +1,6 @@
import importlib_metadata
try:
version = importlib_metadata.version("litellm")
except Exception:
version = "unknown"


@@ -0,0 +1,6 @@
"""
Anthropic module for LiteLLM
"""
from .messages import acreate, create
__all__ = ["acreate", "create"]

View File

@@ -0,0 +1,117 @@
"""
Interface for Anthropic's messages API
Use this to call LLMs in Anthropic /messages Request/Response format
This is an __init__.py file to allow the following interface
- litellm.messages.acreate
- litellm.messages.create
"""
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
anthropic_messages as _async_anthropic_messages,
)
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
async def acreate(
max_tokens: int,
messages: List[Dict],
model: str,
metadata: Optional[Dict] = None,
stop_sequences: Optional[List[str]] = None,
stream: Optional[bool] = False,
system: Optional[str] = None,
temperature: Optional[float] = 1.0,
thinking: Optional[Dict] = None,
tool_choice: Optional[Dict] = None,
tools: Optional[List[Dict]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
"""
Async wrapper for Anthropic's messages API
Args:
max_tokens (int): Maximum tokens to generate (required)
messages (List[Dict]): List of message objects with role and content (required)
model (str): Model name to use (required)
metadata (Dict, optional): Request metadata
stop_sequences (List[str], optional): Custom stop sequences
stream (bool, optional): Whether to stream the response
system (str, optional): System prompt
temperature (float, optional): Sampling temperature (0.0 to 1.0)
thinking (Dict, optional): Extended thinking configuration
tool_choice (Dict, optional): Tool choice configuration
tools (List[Dict], optional): List of tool definitions
top_k (int, optional): Top K sampling parameter
top_p (float, optional): Nucleus sampling parameter
**kwargs: Additional arguments
Returns:
AnthropicMessagesResponse or AsyncIterator: the API response, or an async iterator of stream chunks when stream=True
"""
return await _async_anthropic_messages(
max_tokens=max_tokens,
messages=messages,
model=model,
metadata=metadata,
stop_sequences=stop_sequences,
stream=stream,
system=system,
temperature=temperature,
thinking=thinking,
tool_choice=tool_choice,
tools=tools,
top_k=top_k,
top_p=top_p,
**kwargs,
)
def create(
max_tokens: int,
messages: List[Dict],
model: str,
metadata: Optional[Dict] = None,
stop_sequences: Optional[List[str]] = None,
stream: Optional[bool] = False,
system: Optional[str] = None,
temperature: Optional[float] = 1.0,
thinking: Optional[Dict] = None,
tool_choice: Optional[Dict] = None,
tools: Optional[List[Dict]] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
**kwargs
) -> Union[AnthropicMessagesResponse, Iterator]:
"""
Synchronous wrapper for Anthropic's messages API (not yet implemented)
Args:
max_tokens (int): Maximum tokens to generate (required)
messages (List[Dict]): List of message objects with role and content (required)
model (str): Model name to use (required)
metadata (Dict, optional): Request metadata
stop_sequences (List[str], optional): Custom stop sequences
stream (bool, optional): Whether to stream the response
system (str, optional): System prompt
temperature (float, optional): Sampling temperature (0.0 to 1.0)
thinking (Dict, optional): Extended thinking configuration
tool_choice (Dict, optional): Tool choice configuration
tools (List[Dict], optional): List of tool definitions
top_k (int, optional): Top K sampling parameter
top_p (float, optional): Nucleus sampling parameter
**kwargs: Additional arguments
Returns:
AnthropicMessagesResponse or Iterator: the API response, or an iterator of stream chunks when stream=True
"""
raise NotImplementedError("This function is not implemented")


@@ -0,0 +1,116 @@
## Use LLM API endpoints in the Anthropic Interface
Note: this module is called `anthropic_interface` because `anthropic` is an existing Python package, and reusing that name broke mypy type checking.
## Usage
---
### LiteLLM Python SDK
#### Non-streaming example
```python showLineNumbers title="Example using LiteLLM Python SDK"
import litellm
response = await litellm.anthropic.messages.acreate(
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
api_key=api_key,
model="anthropic/claude-3-haiku-20240307",
max_tokens=100,
)
```
Example response:
```json
{
"content": [
{
"text": "Hi! this is a very short joke",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-7-sonnet-20250219",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 2095,
"output_tokens": 503,
"cache_creation_input_tokens": 2095,
"cache_read_input_tokens": 0
}
}
```
#### Streaming example
```python showLineNumbers title="Example using LiteLLM Python SDK"
import litellm
response = await litellm.anthropic.messages.acreate(
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
api_key=api_key,
model="anthropic/claude-3-haiku-20240307",
max_tokens=100,
stream=True,
)
async for chunk in response:
print(chunk)
```
### LiteLLM Proxy Server
1. Setup config.yaml
```yaml
model_list:
- model_name: anthropic-claude
litellm_params:
model: claude-3-7-sonnet-latest
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
<Tabs>
<TabItem label="Anthropic Python SDK" value="python">
```python showLineNumbers title="Example using LiteLLM Proxy Server"
import anthropic
# point anthropic sdk to litellm proxy
client = anthropic.Anthropic(
base_url="http://0.0.0.0:4000",
api_key="sk-1234",
)
response = client.messages.create(
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
model="anthropic/claude-3-haiku-20240307",
max_tokens=100,
)
```
</TabItem>
<TabItem label="curl" value="curl">
```bash showLineNumbers title="Example using LiteLLM Proxy Server"
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
-H 'content-type: application/json' \
-H "x-api-key: $LITELLM_API_KEY" \
-H 'anthropic-version: 2023-06-01' \
-d '{
"model": "anthropic-claude",
"messages": [
{
"role": "user",
"content": "Hello, can you tell me a short joke?"
}
],
"max_tokens": 100
}'
```
</TabItem>
</Tabs>

File diff suppressed because it is too large


@@ -0,0 +1,161 @@
from typing import Optional, Union
import litellm
from ..exceptions import UnsupportedParamsError
from ..types.llms.openai import *
def get_optional_params_add_message(
role: Optional[str],
content: Optional[
Union[
str,
List[
Union[
MessageContentTextObject,
MessageContentImageFileObject,
MessageContentImageURLObject,
]
],
]
],
attachments: Optional[List[Attachment]],
metadata: Optional[dict],
custom_llm_provider: str,
**kwargs,
):
"""
Azure doesn't support 'attachments' for creating a message
Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
"""
passed_params = locals()
custom_llm_provider = passed_params.pop("custom_llm_provider")
special_params = passed_params.pop("kwargs")
for k, v in special_params.items():
passed_params[k] = v
default_params = {
"role": None,
"content": None,
"attachments": None,
"metadata": None,
}
non_default_params = {
k: v
for k, v in passed_params.items()
if (k in default_params and v != default_params[k])
}
optional_params = {}
## raise exception if non-default value passed for non-openai/azure create-message calls
def _check_valid_arg(supported_params):
if len(non_default_params.keys()) > 0:
keys = list(non_default_params.keys())
for k in keys:
if (
litellm.drop_params is True and k not in supported_params
): # drop the unsupported non-default values
non_default_params.pop(k, None)
elif k not in supported_params:
raise litellm.utils.UnsupportedParamsError(
status_code=500,
message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
k, custom_llm_provider, supported_params
),
)
return non_default_params
if custom_llm_provider == "openai":
optional_params = non_default_params
elif custom_llm_provider == "azure":
supported_params = (
litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
non_default_params=non_default_params, optional_params=optional_params
)
for k in passed_params.keys():
if k not in default_params.keys():
optional_params[k] = passed_params[k]
return optional_params
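# Hedged illustration of the drop/raise logic above (values are hypothetical):
# with custom_llm_provider="azure" and litellm.drop_params = True, an
# unsupported 'attachments' value is silently dropped; with drop_params left
# False, the same call raises UnsupportedParamsError instead.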
def get_optional_params_image_gen(
n: Optional[int] = None,
quality: Optional[str] = None,
response_format: Optional[str] = None,
size: Optional[str] = None,
style: Optional[str] = None,
user: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
**kwargs,
):
# retrieve all parameters passed to the function
passed_params = locals()
custom_llm_provider = passed_params.pop("custom_llm_provider")
special_params = passed_params.pop("kwargs")
for k, v in special_params.items():
passed_params[k] = v
default_params = {
"n": None,
"quality": None,
"response_format": None,
"size": None,
"style": None,
"user": None,
}
non_default_params = {
k: v
for k, v in passed_params.items()
if (k in default_params and v != default_params[k])
}
optional_params = {}
## raise exception if non-default value passed for non-openai/azure image generation calls
def _check_valid_arg(supported_params):
if len(non_default_params.keys()) > 0:
keys = list(non_default_params.keys())
for k in keys:
if (
litellm.drop_params is True and k not in supported_params
): # drop the unsupported non-default values
non_default_params.pop(k, None)
elif k not in supported_params:
raise UnsupportedParamsError(
status_code=500,
message=f"Setting `{k}` is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
)
return non_default_params
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider in litellm.openai_compatible_providers
):
optional_params = non_default_params
elif custom_llm_provider == "bedrock":
supported_params = ["size"]
_check_valid_arg(supported_params=supported_params)
if size is not None:
width, height = size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)
elif custom_llm_provider == "vertex_ai":
supported_params = ["n"]
"""
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
"""
_check_valid_arg(supported_params=supported_params)
if n is not None:
optional_params["sampleCount"] = int(n)
for k in passed_params.keys():
if k not in default_params.keys():
optional_params[k] = passed_params[k]
return optional_params
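A quick illustration of the provider mappings above, run as if from the bottom of this module:

```python
# Bedrock only supports `size`, which is split into integer width/height.
params = get_optional_params_image_gen(
    size="1024x1024",
    custom_llm_provider="bedrock",
)
assert params == {"width": 1024, "height": 1024}

# Vertex AI only supports `n`, which is mapped to `sampleCount`.
params = get_optional_params_image_gen(
    n=2,
    custom_llm_provider="vertex_ai",
)
assert params == {"sampleCount": 2}
```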


@@ -0,0 +1,11 @@
# Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`
Doc: https://docs.litellm.ai/docs/completion/batching
The LiteLLM Python SDK allows you to:
1. `litellm.batch_completion`: batch `litellm.completion` calls against a single model (a usage sketch follows below).
2. `litellm.batch_completion_models`: send a request to multiple language models concurrently and return the response as soon as one of the models responds.
3. `litellm.batch_completion_models_all_responses`: send a request to multiple language models concurrently and return a list of responses from all models that respond.
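A minimal `batch_completion` sketch, matching the linked doc: one inner message list per request, with responses returned in the same order.

```python
import litellm

responses = litellm.batch_completion(
    model="gpt-3.5-turbo",
    messages=[
        [{"role": "user", "content": "good morning?"}],
        [{"role": "user", "content": "what's the time?"}],
    ],
)
```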


@@ -0,0 +1,253 @@
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import List, Optional
import litellm
from litellm._logging import print_verbose
from litellm.utils import get_optional_params
from ..llms.vllm.completion import handler as vllm_handler
def batch_completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: Optional[List] = None,
function_call: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stop=None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
user: Optional[str] = None,
deployment_id=None,
request_timeout: Optional[int] = None,
timeout: Optional[int] = 600,
max_workers: Optional[int] = 100,
# Optional liteLLM function params
**kwargs,
):
"""
Batch litellm.completion function for a given model.
Args:
model (str): The model to use for generating completions.
messages (List, optional): List of messages to use as input for generating completions. Defaults to [].
functions (List, optional): List of functions to use as input for generating completions. Defaults to None.
function_call (str, optional): The function call to use as input for generating completions. Defaults to None.
temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
n (int, optional): The number of completions to generate. Defaults to None.
stream (bool, optional): Whether to stream completions or not. Defaults to None.
stop (optional): The stop parameter for generating completions. Defaults to None.
max_tokens (int, optional): The maximum number of tokens to generate. Defaults to None.
presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
logit_bias (dict, optional): The logit bias for generating completions. Defaults to None.
user (str, optional): The user string for generating completions. Defaults to None.
deployment_id (optional): The deployment ID for generating completions. Defaults to None.
request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
timeout (int, optional): Overall timeout for the call, in seconds. Defaults to 600.
max_workers (int, optional): The maximum number of threads to use for parallel processing. Defaults to 100.
Returns:
list: A list of completion results.
"""
args = locals()
batch_messages = messages
completions = []
custom_llm_provider = None
if model.split("/", 1)[0] in litellm.provider_list:
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
if custom_llm_provider == "vllm":
optional_params = get_optional_params(
functions=functions,
function_call=function_call,
temperature=temperature,
top_p=top_p,
n=n,
stream=stream or False,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
user=user,
# params to identify the model
model=model,
custom_llm_provider=custom_llm_provider,
)
results = vllm_handler.batch_completions(
model=model,
messages=batch_messages,
custom_prompt_dict=litellm.custom_prompt_dict,
optional_params=optional_params,
)
# all non-vLLM models fall back to threaded litellm.completion calls
else:
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for sub_batch in chunks(batch_messages, 100):
for message_list in sub_batch:
kwargs_modified = args.copy()
kwargs_modified.pop("max_workers")
kwargs_modified["messages"] = message_list
original_kwargs = {}
if "kwargs" in kwargs_modified:
original_kwargs = kwargs_modified.pop("kwargs")
future = executor.submit(
litellm.completion, **kwargs_modified, **original_kwargs
)
completions.append(future)
# Retrieve the results from the futures
# results = [future.result() for future in completions]
# return exceptions if any
results = []
for future in completions:
try:
results.append(future.result())
except Exception as exc:
results.append(exc)
return results
# send one request to multiple models
# return as soon as one of the llms responds
def batch_completion_models(*args, **kwargs):
"""
Send a request to multiple language models concurrently and return the response
as soon as one of the models responds.
Args:
*args: Variable-length positional arguments passed to the completion function.
**kwargs: Additional keyword arguments:
- models (str or list of str): The language models to send requests to.
- Other keyword arguments to be passed to the completion function.
Returns:
str or None: The response from one of the language models, or None if no response is received.
Note:
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
It sends requests concurrently and returns the response from the first model that responds.
"""
if "model" in kwargs:
kwargs.pop("model")
if "models" in kwargs:
models = kwargs["models"]
kwargs.pop("models")
futures = {}
with ThreadPoolExecutor(max_workers=len(models)) as executor:
for model in models:
futures[model] = executor.submit(
litellm.completion, *args, model=model, **kwargs
)
for model, future in sorted(
futures.items(), key=lambda x: models.index(x[0])
):
# futures are checked in the order the models were passed in,
# so earlier models take priority over faster ones
if future.result() is not None:
return future.result()
elif "deployments" in kwargs:
deployments = kwargs["deployments"]
kwargs.pop("deployments")
kwargs.pop("model_list")
nested_kwargs = kwargs.pop("kwargs", {})
futures = {}
with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
for deployment in deployments:
for key in kwargs.keys():
if (
key not in deployment
): # don't override deployment values e.g. model name, api base, etc.
deployment[key] = kwargs[key]
kwargs = {**deployment, **nested_kwargs}
futures[deployment["model"]] = executor.submit(
litellm.completion, **kwargs
)
while futures:
# wait for the first returned future
print_verbose("\n\n waiting for next result\n\n")
done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
print_verbose(f"done list\n{done}")
for future in done:
try:
result = future.result()
return result
except Exception:
# if model 1 fails, continue with response from model 2, model3
print_verbose(
"\n\ngot an exception, ignoring, removing from futures"
)
print_verbose(futures)
new_futures = {}
for key, value in futures.items():
if future == value:
print_verbose(f"removing key{key}")
continue
else:
new_futures[key] = value
futures = new_futures
print_verbose(f"new futures{futures}")
continue
print_verbose("\n\ndone looping through futures\n\n")
print_verbose(futures)
return None # If no response is received from any model
def batch_completion_models_all_responses(*args, **kwargs):
"""
Send a request to multiple language models concurrently and return a list of responses
from all models that respond.
Args:
*args: Variable-length positional arguments passed to the completion function.
**kwargs: Additional keyword arguments:
- models (str or list of str): The language models to send requests to.
- Other keyword arguments to be passed to the completion function.
Returns:
list: A list of responses from the language models that responded.
Note:
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
It sends requests concurrently and collects responses from all models that respond.
"""
import concurrent.futures
if "model" in kwargs:
kwargs.pop("model")
if "models" in kwargs:
models = kwargs["models"]
kwargs.pop("models")
else:
raise Exception("'models' param not in kwargs")
responses = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
# submit everything first so the calls actually run concurrently;
# calling future.result() inside the submit loop would serialize them
futures = [
executor.submit(litellm.completion, *args, model=model, **kwargs)
for model in models
]
for future in futures:
result = future.result()
if result is not None:
responses.append(result)
return responses
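A hedged sketch of the fastest-response helper defined above; the model names are illustrative:

```python
import litellm

# Returns the first non-None response, checked in the order the models are listed.
response = litellm.batch_completion_models(
    models=["gpt-3.5-turbo", "claude-3-haiku-20240307"],
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(response)
```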


@@ -0,0 +1,182 @@
import json
from typing import Any, List, Literal, Tuple
import litellm
from litellm._logging import verbose_logger
from litellm.types.llms.openai import Batch
from litellm.types.utils import CallTypes, Usage
async def _handle_completed_batch(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
) -> Tuple[float, Usage, List[str]]:
"""Helper function to process a completed batch and handle logging"""
# Get batch results
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
batch, custom_llm_provider
)
# Calculate costs and usage
batch_cost = await _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
)
batch_models = _get_batch_models_from_file_content(file_content_dictionary)
return batch_cost, batch_usage, batch_models
def _get_batch_models_from_file_content(
file_content_dictionary: List[dict],
) -> List[str]:
"""
Get the models from the file content
"""
batch_models = []
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
_model = _response_body.get("model")
if _model:
batch_models.append(_model)
return batch_models
async def _batch_cost_calculator(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> float:
"""
Calculate the cost of a batch based on the output file id
"""
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
total_cost = _get_batch_job_cost_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
async def _get_batch_output_file_content_as_dictionary(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> List[dict]:
"""
Get the batch output file content as a list of dictionaries
"""
from litellm.files.main import afile_content
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
if batch.output_file_id is None:
raise ValueError("Output file id is None cannot retrieve file content")
_file_content = await afile_content(
file_id=batch.output_file_id,
custom_llm_provider=custom_llm_provider,
)
return _get_file_content_as_dictionary(_file_content.content)
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
"""
Get the file content as a list of dictionaries from JSON Lines format
"""
try:
_file_content_str = file_content.decode("utf-8")
# Split by newlines and parse each line as a separate JSON object
json_objects = []
for line in _file_content_str.strip().split("\n"):
if line: # Skip empty lines
json_objects.append(json.loads(line))
verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
return json_objects
except Exception as e:
raise e
def _get_batch_job_cost_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> float:
"""
Get the cost of a batch job from the file content
"""
try:
total_cost: float = 0.0
# parse the file content as json
verbose_logger.debug(
"file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
)
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
total_cost += litellm.completion_cost(
completion_response=_response_body,
custom_llm_provider=custom_llm_provider,
call_type=CallTypes.aretrieve_batch.value,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
except Exception as e:
verbose_logger.error("error in _get_batch_job_cost_from_file_content", e)
raise e
def _get_batch_job_total_usage_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> Usage:
"""
Get the tokens of a batch job from the file content
"""
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
total_tokens += usage.total_tokens
prompt_tokens += usage.prompt_tokens
completion_tokens += usage.completion_tokens
return Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
"""
Get the tokens of a batch job from the response body
"""
_usage_dict = response_body.get("usage", None) or {}
usage: Usage = Usage(**_usage_dict)
return usage
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
"""
Get the response from the batch job output file
"""
_response: dict = batch_job_output_file.get("response", None) or {}
_response_body = _response.get("body", None) or {}
return _response_body
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
"""
Check if the batch job response status == 200
"""
_response: dict = batch_job_output_file.get("response", None) or {}
return _response.get("status_code", None) == 200
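A small sketch of the JSON Lines helpers above, run as if from the same module, with a fabricated two-line batch output file:

```python
# Hypothetical output file: one successful request, one failed request.
raw = (
    b'{"response": {"status_code": 200, "body": {"model": "gpt-4o",'
    b' "usage": {"total_tokens": 10, "prompt_tokens": 7, "completion_tokens": 3}}}}\n'
    b'{"response": {"status_code": 500, "body": {}}}\n'
)

rows = _get_file_content_as_dictionary(raw)
assert len(rows) == 2
assert _batch_response_was_successful(rows[0]) is True
assert _get_batch_models_from_file_content(rows) == ["gpt-4o"]
```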


@@ -0,0 +1,792 @@
"""
Main File for Batches API implementation
https://platform.openai.com/docs/api-reference/batch
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
import asyncio
import contextvars
import os
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure.batches.handler import AzureBatchesAPI
from litellm.llms.openai.openai import OpenAIBatchesAPI
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LiteLLMBatch
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
azure_batches_instance = AzureBatchesAPI()
vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
#################################################
@client
async def acreate_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
"""
Async: Creates and executes a batch from an uploaded file of requests
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
"""
try:
loop = asyncio.get_event_loop()
kwargs["acreate_batch"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_batch,
completion_window,
endpoint,
input_file_id,
custom_llm_provider,
metadata,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise e
@client
def create_batch(
completion_window: Literal["24h"],
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
input_file_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
"""
Creates and executes a batch from an uploaded file of requests
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_call_id = kwargs.get("litellm_call_id", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
_is_async = kwargs.pop("acreate_batch", False) is True
litellm_params = get_litellm_params(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
litellm_logging_obj.update_environment_variables(
model=None,
user=None,
optional_params=optional_params.model_dump(),
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
**optional_params.model_dump(exclude_unset=True),
},
custom_llm_provider=custom_llm_provider,
)
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_batch_request = CreateBatchRequest(
completion_window=completion_window,
endpoint=endpoint,
input_file_id=input_file_id,
metadata=metadata,
extra_headers=extra_headers,
extra_body=extra_body,
)
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_batches_instance.create_batch(
api_base=api_base,
api_key=api_key,
organization=organization,
create_batch_data=_create_batch_request,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
)
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.create_batch(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
litellm_params=litellm_params,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_batches_instance.create_batch(
_is_async=_is_async,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support custom_llm_provider={} for 'create_batch'".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
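# Hedged usage sketch for create_batch above (the file id is hypothetical):
#   batch = create_batch(
#       completion_window="24h",
#       endpoint="/v1/chat/completions",
#       input_file_id="file-abc123",
#       custom_llm_provider="openai",
#   )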
@client
async def aretrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> LiteLLMBatch:
"""
Async: Retrieves a batch.
LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
"""
try:
loop = asyncio.get_event_loop()
kwargs["aretrieve_batch"] = True
# Use a partial function to pass your keyword arguments
func = partial(
retrieve_batch,
batch_id,
custom_llm_provider,
metadata,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def retrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
"""
Retrieves a batch.
LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
**kwargs,
)
litellm_logging_obj.update_environment_variables(
model=None,
user=None,
optional_params=optional_params.model_dump(),
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_retrieve_batch_request = RetrieveBatchRequest(
batch_id=batch_id,
extra_headers=extra_headers,
extra_body=extra_body,
)
_is_async = kwargs.pop("aretrieve_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_batches_instance.retrieve_batch(
_is_async=_is_async,
retrieve_batch_data=_retrieve_batch_request,
api_base=api_base,
api_key=api_key,
organization=organization,
timeout=timeout,
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.retrieve_batch(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
retrieve_batch_data=_retrieve_batch_request,
litellm_params=litellm_params,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_batches_instance.retrieve_batch(
_is_async=_is_async,
batch_id=batch_id,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
async def alist_batches(
after: Optional[str] = None,
limit: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
Async: List your organization's batches.
"""
try:
loop = asyncio.get_event_loop()
kwargs["alist_batches"] = True
# Use a partial function to pass your keyword arguments
func = partial(
list_batches,
after,
limit,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def list_batches(
after: Optional[str] = None,
limit: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
    List your organization's batches.
"""
try:
# set API KEY
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
**kwargs,
)
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("alist_batches", False) is True
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
response = openai_batches_instance.list_batches(
_is_async=_is_async,
after=after,
limit=limit,
api_base=api_base,
api_key=api_key,
organization=organization,
timeout=timeout,
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.list_batches(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
litellm_params=litellm_params,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'list_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
async def acancel_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
"""
Async: Cancels a batch.
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
"""
try:
loop = asyncio.get_event_loop()
kwargs["acancel_batch"] = True
# Use a partial function to pass your keyword arguments
func = partial(
cancel_batch,
batch_id,
custom_llm_provider,
metadata,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise e
def cancel_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
"""
Cancels a batch.
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
**kwargs,
)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_cancel_batch_request = CancelBatchRequest(
batch_id=batch_id,
extra_headers=extra_headers,
extra_body=extra_body,
)
_is_async = kwargs.pop("acancel_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_batches_instance.cancel_batch(
_is_async=_is_async,
cancel_batch_data=_cancel_batch_request,
api_base=api_base,
api_key=api_key,
organization=organization,
timeout=timeout,
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.cancel_batch(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
cancel_batch_data=_cancel_batch_request,
litellm_params=litellm_params,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="cancel_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
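# Illustrative usage sketch (assumption: this demo block is not part of the
# original module; the batch id is a placeholder and OPENAI_API_KEY must be set
# for a live call).
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        batch = await aretrieve_batch(
            batch_id="batch_abc123",  # hypothetical id - replace with a real batch id
            custom_llm_provider="openai",
        )
        print(batch.id, batch.status)

    # asyncio.run(_demo())  # uncomment to run against a live batch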

View File

@@ -0,0 +1,230 @@
# +-----------------------------------------------+
# | |
# | NOT PROXY BUDGET MANAGER |
# | proxy budget manager is in proxy_server.py |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import json
import os
import threading
import time
from typing import Literal, Optional
import litellm
from litellm.constants import (
DAYS_IN_A_MONTH,
DAYS_IN_A_WEEK,
DAYS_IN_A_YEAR,
HOURS_IN_A_DAY,
)
from litellm.utils import ModelResponse
class BudgetManager:
def __init__(
self,
project_name: str,
client_type: str = "local",
api_base: Optional[str] = None,
headers: Optional[dict] = None,
):
self.client_type = client_type
self.project_name = project_name
self.api_base = api_base or "https://api.litellm.ai"
self.headers = headers or {"Content-Type": "application/json"}
## load the data or init the initial dictionaries
self.load_data()
def print_verbose(self, print_statement):
try:
if litellm.set_verbose:
import logging
logging.info(print_statement)
except Exception:
pass
def load_data(self):
if self.client_type == "local":
# Check if user dict file exists
if os.path.isfile("user_cost.json"):
# Load the user dict
with open("user_cost.json", "r") as json_file:
self.user_dict = json.load(json_file)
else:
self.print_verbose("User Dictionary not found!")
self.user_dict = {}
self.print_verbose(f"user dict from local: {self.user_dict}")
elif self.client_type == "hosted":
# Load the user_dict from hosted db
url = self.api_base + "/get_budget"
data = {"project_name": self.project_name}
response = litellm.module_level_client.post(
url, headers=self.headers, json=data
)
response = response.json()
if response["status"] == "error":
self.user_dict = (
{}
) # assume this means the user dict hasn't been stored yet
else:
self.user_dict = response["data"]
def create_budget(
self,
total_budget: float,
user: str,
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
        created_at: Optional[float] = None,
    ):
        if created_at is None:
            created_at = time.time()  # evaluated per call; a time.time() default argument would be frozen at import time
self.user_dict[user] = {"total_budget": total_budget}
if duration is None:
return self.user_dict[user]
if duration == "daily":
duration_in_days = 1
elif duration == "weekly":
duration_in_days = DAYS_IN_A_WEEK
elif duration == "monthly":
duration_in_days = DAYS_IN_A_MONTH
elif duration == "yearly":
duration_in_days = DAYS_IN_A_YEAR
else:
raise ValueError(
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
)
self.user_dict[user] = {
"total_budget": total_budget,
"duration": duration_in_days,
"created_at": created_at,
"last_updated_at": created_at,
}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return self.user_dict[user]
def projected_cost(self, model: str, messages: list, user: str):
text = "".join(message["content"] for message in messages)
prompt_tokens = litellm.token_counter(model=model, text=text)
prompt_cost, _ = litellm.cost_per_token(
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
)
current_cost = self.user_dict[user].get("current_cost", 0)
projected_cost = prompt_cost + current_cost
return projected_cost
def get_total_budget(self, user: str):
return self.user_dict[user]["total_budget"]
def update_cost(
self,
user: str,
completion_obj: Optional[ModelResponse] = None,
model: Optional[str] = None,
input_text: Optional[str] = None,
output_text: Optional[str] = None,
):
if model and input_text and output_text:
prompt_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": input_text}]
)
completion_tokens = litellm.token_counter(
model=model, messages=[{"role": "user", "content": output_text}]
)
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = litellm.cost_per_token(
model=model,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
elif completion_obj:
cost = litellm.completion_cost(completion_response=completion_obj)
            model = completion_obj["model"]
else:
raise ValueError(
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
)
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
"current_cost", 0
)
if "model_cost" in self.user_dict[user]:
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
"model_cost"
].get(model, 0)
else:
self.user_dict[user]["model_cost"] = {model: cost}
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
return {"user": self.user_dict[user]}
def get_current_cost(self, user):
return self.user_dict[user].get("current_cost", 0)
def get_model_cost(self, user):
return self.user_dict[user].get("model_cost", 0)
def is_valid_user(self, user: str) -> bool:
return user in self.user_dict
def get_users(self):
return list(self.user_dict.keys())
def reset_cost(self, user):
self.user_dict[user]["current_cost"] = 0
self.user_dict[user]["model_cost"] = {}
return {"user": self.user_dict[user]}
def reset_on_duration(self, user: str):
# Get current and creation time
last_updated_at = self.user_dict[user]["last_updated_at"]
current_time = time.time()
# Convert duration from days to seconds
duration_in_seconds = (
self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
)
# Check if duration has elapsed
if current_time - last_updated_at >= duration_in_seconds:
# Reset cost if duration has elapsed and update the creation time
self.reset_cost(user)
self.user_dict[user]["last_updated_at"] = current_time
self._save_data_thread() # Save the data
def update_budget_all_users(self):
for user in self.get_users():
if "duration" in self.user_dict[user]:
self.reset_on_duration(user)
def _save_data_thread(self):
thread = threading.Thread(
target=self.save_data
) # [Non-Blocking]: saves data without blocking execution
thread.start()
def save_data(self):
if self.client_type == "local":
import json
# save the user dict
with open("user_cost.json", "w") as json_file:
json.dump(
self.user_dict, json_file, indent=4
) # Indent for pretty formatting
return {"status": "success"}
elif self.client_type == "hosted":
url = self.api_base + "/set_budget"
data = {"project_name": self.project_name, "user_dict": self.user_dict}
response = litellm.module_level_client.post(
url, headers=self.headers, json=data
)
response = response.json()
return response
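# Illustrative usage sketch (assumption: this demo block is not part of the
# original module). With client_type="local" the manager reads and writes a
# user_cost.json file in the current working directory.
if __name__ == "__main__":
    budget_manager = BudgetManager(project_name="demo_project")
    budget_manager.create_budget(
        total_budget=10.0, user="user@example.com", duration="monthly"
    )
    print(budget_manager.get_total_budget("user@example.com"))  # 10.0
    print(budget_manager.get_current_cost("user@example.com"))  # 0
    print(budget_manager.is_valid_user("user@example.com"))  # True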

View File

@@ -0,0 +1,40 @@
# Caching on LiteLLM
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
The following caching mechanisms are supported:
1. **RedisCache**
2. **RedisSemanticCache**
3. **QdrantSemanticCache**
4. **InMemoryCache**
5. **DiskCache**
6. **S3Cache**
7. **DualCache** (updates both Redis and an in-memory cache simultaneously)
## Folder Structure
```
litellm/caching/
├── base_cache.py
├── caching.py
├── caching_handler.py
├── disk_cache.py
├── dual_cache.py
├── in_memory_cache.py
├── qdrant_semantic_cache.py
├── redis_cache.py
├── redis_cluster_cache.py
├── redis_semantic_cache.py
└── s3_cache.py
```
## Documentation
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
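## Quick Example
A minimal sketch of enabling the completion cache from Python. The Redis host, port, and password below are placeholders for your own instance:
```python
import litellm
from litellm import completion
from litellm.caching.caching import Cache, LiteLLMCacheType

litellm.cache = Cache(
    type=LiteLLMCacheType.REDIS,
    host="localhost",  # placeholder - point at your Redis instance
    port="6379",
    password="",
)

messages = [{"role": "user", "content": "Hello, world"}]
first = completion(model="gpt-3.5-turbo", messages=messages)
second = completion(model="gpt-3.5-turbo", messages=messages)  # served from the cache
```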

View File

@@ -0,0 +1,9 @@
from .caching import Cache, LiteLLMCacheType
from .disk_cache import DiskCache
from .dual_cache import DualCache
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache

View File

@@ -0,0 +1,30 @@
from functools import lru_cache
from typing import Callable, Optional, TypeVar
T = TypeVar("T")
def lru_cache_wrapper(
maxsize: Optional[int] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""
Wrapper for lru_cache that caches success and exceptions
"""
def decorator(f: Callable[..., T]) -> Callable[..., T]:
@lru_cache(maxsize=maxsize)
def wrapper(*args, **kwargs):
try:
return ("success", f(*args, **kwargs))
except Exception as e:
return ("error", e)
def wrapped(*args, **kwargs):
result = wrapper(*args, **kwargs)
if result[0] == "error":
raise result[1]
return result[1]
return wrapped
return decorator
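# Illustrative usage sketch (assumption: this demo block is not part of the
# original module). Both return values and raised exceptions are memoized.
if __name__ == "__main__":
    calls = {"count": 0}

    @lru_cache_wrapper(maxsize=4)
    def flaky_double(x: int) -> int:
        calls["count"] += 1
        if x < 0:
            raise ValueError("negative input")
        return x * 2

    assert flaky_double(2) == 4
    assert flaky_double(2) == 4
    assert calls["count"] == 1  # the second call was served from the cache
    for _ in range(2):
        try:
            flaky_double(-1)
        except ValueError:
            pass
    assert calls["count"] == 2  # the exception was cached as well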

View File

@@ -0,0 +1,55 @@
"""
Base Cache implementation. All cache implementations should inherit from this class.
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Union
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class BaseCache(ABC):
def __init__(self, default_ttl: int = 60):
self.default_ttl = default_ttl
def get_ttl(self, **kwargs) -> Optional[int]:
kwargs_ttl: Optional[int] = kwargs.get("ttl")
if kwargs_ttl is not None:
try:
return int(kwargs_ttl)
except ValueError:
return self.default_ttl
return self.default_ttl
def set_cache(self, key, value, **kwargs):
raise NotImplementedError
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError
@abstractmethod
async def async_set_cache_pipeline(self, cache_list, **kwargs):
pass
def get_cache(self, key, **kwargs):
raise NotImplementedError
async def async_get_cache(self, key, **kwargs):
raise NotImplementedError
async def batch_cache_write(self, key, value, **kwargs):
raise NotImplementedError
async def disconnect(self):
raise NotImplementedError
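# Minimal illustrative subclass (assumption: this demo block is not part of the
# original module). It shows the surface a concrete cache must implement.
if __name__ == "__main__":
    import asyncio

    class DictCache(BaseCache):
        def __init__(self):
            super().__init__(default_ttl=60)
            self._store: dict = {}

        def set_cache(self, key, value, **kwargs):
            self._store[key] = value

        def get_cache(self, key, **kwargs):
            return self._store.get(key)

        async def async_set_cache(self, key, value, **kwargs):
            self.set_cache(key, value, **kwargs)

        async def async_get_cache(self, key, **kwargs):
            return self.get_cache(key, **kwargs)

        async def async_set_cache_pipeline(self, cache_list, **kwargs):
            # bulk write: cache_list is a list of (key, value) tuples
            for key, value in cache_list:
                self.set_cache(key, value, **kwargs)

    cache = DictCache()
    cache.set_cache("greeting", "hello")
    assert cache.get_cache("greeting") == "hello"
    asyncio.run(cache.async_set_cache_pipeline([("a", 1), ("b", 2)]))
    assert cache.get_cache("b") == 2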

View File

@@ -0,0 +1,798 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import *
from litellm.types.utils import all_litellm_params
from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache # noqa
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache
def print_verbose(print_statement):
try:
verbose_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except Exception:
pass
class CacheMode(str, Enum):
default_on = "default_on"
default_off = "default_off"
#### LiteLLM.Completion / Embedding Cache ####
class Cache:
def __init__(
self,
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
mode: Optional[
CacheMode
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
namespace: Optional[str] = None,
ttl: Optional[float] = None,
default_in_memory_ttl: Optional[float] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
],
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None,
s3_use_ssl: Optional[bool] = True,
s3_verify: Optional[Union[bool, str]] = None,
s3_endpoint_url: Optional[str] = None,
s3_aws_access_key_id: Optional[str] = None,
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None,
s3_path: Optional[str] = None,
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
redis_semantic_cache_index_name: Optional[str] = None,
redis_flush_size: Optional[int] = None,
redis_startup_nodes: Optional[List] = None,
disk_cache_dir: Optional[str] = None,
qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
**kwargs,
):
"""
Initializes the cache based on the given type.
Args:
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
# Redis Cache Args
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
ttl (float, optional): The ttl for the Redis cache
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
# Qdrant Cache Args
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
            similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic" or "qdrant-semantic".
# Disk Cache Args
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
# S3 Cache Args
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
            s3_use_ssl (bool, optional): Whether to use SSL for the s3 cache. Defaults to True.
            s3_verify (bool, optional): SSL certificate verification setting for the s3 cache. Defaults to None.
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
# Common Cache Args
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache
Raises:
ValueError: If an invalid cache type is provided.
Returns:
None. Cache is set as a litellm param
"""
if type == LiteLLMCacheType.REDIS:
if redis_startup_nodes:
self.cache: BaseCache = RedisClusterCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
else:
self.cache = RedisCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
**kwargs,
)
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
self.cache = RedisSemanticCache(
host=host,
port=port,
password=password,
similarity_threshold=similarity_threshold,
embedding_model=redis_semantic_cache_embedding_model,
index_name=redis_semantic_cache_index_name,
**kwargs,
)
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
self.cache = QdrantSemanticCache(
qdrant_api_base=qdrant_api_base,
qdrant_api_key=qdrant_api_key,
collection_name=qdrant_collection_name,
similarity_threshold=similarity_threshold,
quantization_config=qdrant_quantization_config,
embedding_model=qdrant_semantic_cache_embedding_model,
)
elif type == LiteLLMCacheType.LOCAL:
self.cache = InMemoryCache()
elif type == LiteLLMCacheType.S3:
self.cache = S3Cache(
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
s3_api_version=s3_api_version,
s3_use_ssl=s3_use_ssl,
s3_verify=s3_verify,
s3_endpoint_url=s3_endpoint_url,
s3_aws_access_key_id=s3_aws_access_key_id,
s3_aws_secret_access_key=s3_aws_secret_access_key,
s3_aws_session_token=s3_aws_session_token,
s3_config=s3_config,
s3_path=s3_path,
**kwargs,
)
elif type == LiteLLMCacheType.DISK:
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
self.type = type
self.namespace = namespace
self.redis_flush_size = redis_flush_size
self.ttl = ttl
self.mode: CacheMode = mode or CacheMode.default_on
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
self.ttl = default_in_memory_ttl
if (
self.type == LiteLLMCacheType.REDIS
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
) and default_in_redis_ttl is not None:
self.ttl = default_in_redis_ttl
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
def get_cache_key(self, **kwargs) -> str:
"""
Get the cache key for the given arguments.
Args:
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
str: The cache key generated from the arguments, or None if no cache key could be generated.
"""
cache_key = ""
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
if preset_cache_key is not None:
verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
return preset_cache_key
combined_kwargs = ModelParamHelper._get_all_llm_api_params()
litellm_param_kwargs = all_litellm_params
for param in kwargs:
if param in combined_kwargs:
param_value: Optional[str] = self._get_param_value(param, kwargs)
if param_value is not None:
cache_key += f"{str(param)}: {str(param_value)}"
elif (
param not in litellm_param_kwargs
): # check if user passed in optional param - e.g. top_k
if (
litellm.enable_caching_on_provider_specific_optional_params is True
): # feature flagged for now
if kwargs[param] is None:
continue # ignore None params
param_value = kwargs[param]
cache_key += f"{str(param)}: {str(param_value)}"
verbose_logger.debug("\nCreated cache key: %s", cache_key)
hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
self._set_preset_cache_key_in_kwargs(
preset_cache_key=hashed_cache_key, **kwargs
)
return hashed_cache_key
def _get_param_value(
self,
param: str,
kwargs: dict,
) -> Optional[str]:
"""
Get the value for the given param from kwargs
"""
if param == "model":
return self._get_model_param_value(kwargs)
elif param == "file":
return self._get_file_param_value(kwargs)
return kwargs[param]
def _get_model_param_value(self, kwargs: dict) -> str:
"""
Handles getting the value for the 'model' param from kwargs
1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
3. Else use the `model` passed in kwargs
"""
metadata: Dict = kwargs.get("metadata", {}) or {}
litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
model_group: Optional[str] = metadata.get(
"model_group"
) or metadata_in_litellm_params.get("model_group")
caching_group = self._get_caching_group(metadata, model_group)
return caching_group or model_group or kwargs["model"]
def _get_caching_group(
self, metadata: dict, model_group: Optional[str]
) -> Optional[str]:
caching_groups: Optional[List] = metadata.get("caching_groups", [])
if caching_groups:
for group in caching_groups:
if model_group in group:
return str(group)
return None
def _get_file_param_value(self, kwargs: dict) -> str:
"""
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
"""
file = kwargs.get("file")
metadata = kwargs.get("metadata", {})
litellm_params = kwargs.get("litellm_params", {})
return (
metadata.get("file_checksum")
or getattr(file, "name", None)
or metadata.get("file_name")
or litellm_params.get("file_name")
)
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
"""
Get the preset cache key from kwargs["litellm_params"]
We use _get_preset_cache_keys for two reasons
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
2. avoid doing duplicate / repeated work
"""
if kwargs:
if "litellm_params" in kwargs:
return kwargs["litellm_params"].get("preset_cache_key", None)
return None
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
"""
Set the calculated cache key in kwargs
This is used to avoid doing duplicate / repeated work
Placed in kwargs["litellm_params"]
"""
if kwargs:
if "litellm_params" in kwargs:
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
@staticmethod
def _get_hashed_cache_key(cache_key: str) -> str:
"""
Get the hashed cache key for the given cache key.
Use hashlib to create a sha256 hash of the cache key
Args:
cache_key (str): The cache key to hash.
Returns:
str: The hashed cache key.
"""
hash_object = hashlib.sha256(cache_key.encode())
# Hexadecimal representation of the hash
hash_hex = hash_object.hexdigest()
verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
return hash_hex
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
"""
If a redis namespace is provided, add it to the cache key
Args:
hash_hex (str): The hashed cache key.
**kwargs: Additional keyword arguments.
Returns:
str: The final hashed cache key with the redis namespace.
"""
dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
namespace = (
dynamic_cache_control.get("namespace")
or kwargs.get("metadata", {}).get("redis_namespace")
or self.namespace
)
if namespace:
hash_hex = f"{namespace}:{hash_hex}"
verbose_logger.debug("Final hashed key: %s", hash_hex)
return hash_hex
def generate_streaming_content(self, content):
chunk_size = 5 # Adjust the chunk size as needed
for i in range(0, len(content), chunk_size):
yield {
"choices": [
{
"delta": {
"role": "assistant",
"content": content[i : i + chunk_size],
}
}
]
}
time.sleep(CACHED_STREAMING_CHUNK_DELAY)
def _get_cache_logic(
self,
cached_result: Optional[Any],
max_age: Optional[float],
):
"""
Common get cache logic across sync + async implementations
"""
# Check if a timestamp was stored with the cached response
if (
cached_result is not None
and isinstance(cached_result, dict)
and "timestamp" in cached_result
):
timestamp = cached_result["timestamp"]
current_time = time.time()
# Calculate age of the cached response
response_age = current_time - timestamp
# Check if the cached response is older than the max-age
if max_age is not None and response_age > max_age:
return None # Cached response is too old
# If the response is fresh, or there's no max-age requirement, return the cached response
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_result.get("response")
try:
if isinstance(cached_response, dict):
pass
else:
cached_response = json.loads(
cached_response # type: ignore
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response) # type: ignore
return cached_response
return cached_result
def get_cache(self, **kwargs):
"""
Retrieves the cached result for the given arguments.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
The cached result if it exists, otherwise None.
"""
try: # never block execution
if self.should_use_cache(**kwargs) is not True:
return
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
max_age = (
cache_control_args.get("s-maxage")
or cache_control_args.get("s-max-age")
or float("inf")
)
                cached_result = self.cache.get_cache(cache_key, messages=messages)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
async def async_get_cache(self, **kwargs):
"""
Async get cache implementation.
Used for embedding calls in async wrapper
"""
try: # never block execution
if self.should_use_cache(**kwargs) is not True:
return
kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
def _add_cache_logic(self, result, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
try:
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
if isinstance(result, BaseModel):
result = result.model_dump_json()
## DEFAULT TTL ##
if self.ttl is not None:
kwargs["ttl"] = self.ttl
## Get Cache-Controls ##
_cache_kwargs = kwargs.get("cache", None)
if isinstance(_cache_kwargs, dict):
for k, v in _cache_kwargs.items():
if k == "ttl":
kwargs["ttl"] = v
cached_data = {"timestamp": time.time(), "response": result}
return cache_key, cached_data, kwargs
else:
raise Exception("cache key is None")
except Exception as e:
raise e
def add_cache(self, result, **kwargs):
"""
Adds a result to the cache.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
None
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache(self, result, **kwargs):
"""
Async implementation of add_cache
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
if self.type == "redis" and self.redis_flush_size is not None:
# high traffic - fill in results in memory and then flush
await self.batch_cache_write(result, **kwargs)
else:
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, **kwargs
)
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache_pipeline(self, result, **kwargs):
"""
Async implementation of add_cache for Embedding calls
Does a bulk write, to prevent using too many clients
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
# set default ttl if not set
if self.ttl is not None:
kwargs["ttl"] = self.ttl
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx]
cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_response,
**kwargs,
)
cache_list.append((cache_key, cached_data))
await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# if async_set_cache_pipeline:
# await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# else:
# tasks = []
# for val in cache_list:
# tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
# await asyncio.gather(*tasks)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
def should_use_cache(self, **kwargs):
"""
Returns true if we should use the cache for LLM API calls
If cache is default_on then this is True
If cache is default_off then this is only true when user has opted in to use cache
"""
if self.mode == CacheMode.default_on:
return True
# when mode == default_off -> Cache is opt in only
_cache = kwargs.get("cache", None)
verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
if _cache and isinstance(_cache, dict):
if _cache.get("use-cache", False) is True:
return True
return False
async def batch_cache_write(self, result, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):
        cache_ping = getattr(self.cache, "ping", None)  # default None, otherwise getattr raises and the check below is dead code
if cache_ping:
return await cache_ping()
return None
async def delete_cache_keys(self, keys):
        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
if cache_delete_cache_keys:
return await cache_delete_cache_keys(keys)
return None
async def disconnect(self):
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
def _supports_async(self) -> bool:
"""
Internal method to check if the cache type supports async get/set operations
Only S3 Cache Does NOT support async operations
"""
if self.type and self.type == LiteLLMCacheType.S3:
return False
return True
def enable_cache(
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
],
**kwargs,
):
"""
Enable cache with the specified configuration.
Args:
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
host (Optional[str]): The host address of the cache server. Defaults to None.
port (Optional[str]): The port number of the cache server. Defaults to None.
password (Optional[str]): The password for the cache server. Defaults to None.
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
**kwargs: Additional keyword arguments.
Returns:
None
Raises:
None
"""
print_verbose("LiteLLM: Enabling Cache")
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
if litellm.cache is None:
litellm.cache = Cache(
type=type,
host=host,
port=port,
password=password,
supported_call_types=supported_call_types,
**kwargs,
)
print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
def update_cache(
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
],
**kwargs,
):
"""
Update the cache for LiteLLM.
Args:
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
host (Optional[str]): The host of the cache. Defaults to None.
port (Optional[str]): The port of the cache. Defaults to None.
password (Optional[str]): The password for the cache. Defaults to None.
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
**kwargs: Additional keyword arguments for the cache.
Returns:
None
"""
print_verbose("LiteLLM: Updating Cache")
litellm.cache = Cache(
type=type,
host=host,
port=port,
password=password,
supported_call_types=supported_call_types,
**kwargs,
)
print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
def disable_cache():
"""
Disable the cache used by LiteLLM.
This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
Parameters:
None
Returns:
None
"""
from contextlib import suppress
print_verbose("LiteLLM: Disabling Cache")
with suppress(ValueError):
litellm.input_callback.remove("cache")
litellm.success_callback.remove("cache")
litellm._async_success_callback.remove("cache")
litellm.cache = None
print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")

View File

@@ -0,0 +1,906 @@
"""
This contains LLMCachingHandler
This exposes two methods:
- async_get_cache
- async_set_cache
This file is a wrapper around caching.py
This class is used to handle caching logic specific for LLM API requests (completion / embedding / text_completion / transcription etc)
It utilizes the (RedisCache, s3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has setup
In each method it will call the appropriate method from caching.py
"""
import asyncio
import datetime
import inspect
import threading
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Union,
)
from pydantic import BaseModel
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import S3Cache
from litellm.litellm_core_utils.logging_utils import (
_assemble_complete_response_from_streaming_chunks,
)
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
CallTypes,
Embedding,
EmbeddingResponse,
ModelResponse,
TextCompletionResponse,
TranscriptionResponse,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import CustomStreamWrapper
else:
LiteLLMLoggingObj = Any
CustomStreamWrapper = Any
class CachingHandlerResponse(BaseModel):
"""
This is the response object for the caching handler. We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses
For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others
"""
cached_result: Optional[Any] = None
final_embedding_cached_response: Optional[EmbeddingResponse] = None
embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
class LLMCachingHandler:
def __init__(
self,
original_function: Callable,
request_kwargs: Dict[str, Any],
start_time: datetime.datetime,
):
self.async_streaming_chunks: List[ModelResponse] = []
self.sync_streaming_chunks: List[ModelResponse] = []
self.request_kwargs = request_kwargs
self.original_function = original_function
self.start_time = start_time
pass
async def _async_get_cache(
self,
model: str,
original_function: Callable,
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
call_type: str,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
) -> CachingHandlerResponse:
"""
Internal method to get from the cache.
Handles different call types (embeddings, chat/completions, text_completion, transcription)
and accordingly returns the cached response
Args:
model: str:
original_function: Callable:
logging_obj: LiteLLMLoggingObj:
start_time: datetime.datetime:
call_type: str:
kwargs: Dict[str, Any]:
args: Optional[Tuple[Any, ...]] = None:
Returns:
CachingHandlerResponse:
Raises:
None
"""
from litellm.utils import CustomStreamWrapper
args = args or ()
final_embedding_cached_response: Optional[EmbeddingResponse] = None
embedding_all_elements_cache_hit: bool = False
cached_result: Optional[Any] = None
if (
(kwargs.get("caching", None) is None and litellm.cache is not None)
or kwargs.get("caching", False) is True
) and (
kwargs.get("cache", {}).get("no-cache", False) is not True
): # allow users to control returning cached responses from the completion function
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
verbose_logger.debug("Checking Cache")
cached_result = await self._retrieve_from_cache(
call_type=call_type,
kwargs=kwargs,
args=args,
)
if cached_result is not None and not isinstance(cached_result, list):
verbose_logger.debug("Cache Hit!")
cache_hit = True
end_time = datetime.datetime.now()
model, _, _, _ = litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
self._update_litellm_logging_obj_environment(
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
cached_result=cached_result,
is_async=True,
)
call_type = original_function.__name__
cached_result = self._convert_cached_result_to_model_response(
cached_result=cached_result,
call_type=call_type,
kwargs=kwargs,
logging_obj=logging_obj,
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
args=args,
)
if kwargs.get("stream", False) is False:
# LOG SUCCESS
self._async_log_cache_hit_on_callbacks(
logging_obj=logging_obj,
cached_result=cached_result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
)
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
**kwargs
)
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return CachingHandlerResponse(cached_result=cached_result)
elif (
call_type == CallTypes.aembedding.value
and cached_result is not None
and isinstance(cached_result, list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
(
final_embedding_cached_response,
embedding_all_elements_cache_hit,
) = self._process_async_embedding_cached_response(
final_embedding_cached_response=final_embedding_cached_response,
cached_result=cached_result,
kwargs=kwargs,
logging_obj=logging_obj,
start_time=start_time,
model=model,
)
return CachingHandlerResponse(
final_embedding_cached_response=final_embedding_cached_response,
embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
)
verbose_logger.debug(f"CACHE RESULT: {cached_result}")
return CachingHandlerResponse(
cached_result=cached_result,
final_embedding_cached_response=final_embedding_cached_response,
)
def _sync_get_cache(
self,
model: str,
original_function: Callable,
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
call_type: str,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
) -> CachingHandlerResponse:
from litellm.utils import CustomStreamWrapper
args = args or ()
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
print_verbose("Checking Cache")
cached_result = litellm.cache.get_cache(**new_kwargs)
if cached_result is not None:
if "detail" in cached_result:
# implies an error occurred
pass
else:
call_type = original_function.__name__
cached_result = self._convert_cached_result_to_model_response(
cached_result=cached_result,
call_type=call_type,
kwargs=kwargs,
logging_obj=logging_obj,
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
args=args,
)
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model or "",
custom_llm_provider=kwargs.get("custom_llm_provider", None),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
self._update_litellm_logging_obj_environment(
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
cached_result=cached_result,
is_async=False,
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
**kwargs
)
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return CachingHandlerResponse(cached_result=cached_result)
return CachingHandlerResponse(cached_result=cached_result)
def _process_async_embedding_cached_response(
self,
final_embedding_cached_response: Optional[EmbeddingResponse],
cached_result: List[Optional[Dict[str, Any]]],
kwargs: Dict[str, Any],
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
model: str,
) -> Tuple[Optional[EmbeddingResponse], bool]:
"""
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others
This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
Args:
final_embedding_cached_response: Optional[EmbeddingResponse]:
cached_result: List[Optional[Dict[str, Any]]]:
kwargs: Dict[str, Any]:
logging_obj: LiteLLMLoggingObj:
start_time: datetime.datetime:
model: str:
Returns:
Tuple[Optional[EmbeddingResponse], bool]:
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
"""
embedding_all_elements_cache_hit: bool = False
remaining_list = []
non_null_list = []
for idx, cr in enumerate(cached_result):
if cr is None:
remaining_list.append(kwargs["input"][idx])
else:
non_null_list.append((idx, cr))
original_kwargs_input = kwargs["input"]
kwargs["input"] = remaining_list
if len(non_null_list) > 0:
print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
final_embedding_cached_response = EmbeddingResponse(
model=kwargs.get("model"),
data=[None] * len(original_kwargs_input),
)
final_embedding_cached_response._hidden_params["cache_hit"] = True
for val in non_null_list:
idx, cr = val # (idx, cr) tuple
if cr is not None:
final_embedding_cached_response.data[idx] = Embedding(
embedding=cr["embedding"],
index=idx,
object="embedding",
)
if len(remaining_list) == 0:
# LOG SUCCESS
cache_hit = True
embedding_all_elements_cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
self._update_litellm_logging_obj_environment(
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
cached_result=final_embedding_cached_response,
is_async=True,
is_embedding=True,
)
self._async_log_cache_hit_on_callbacks(
logging_obj=logging_obj,
cached_result=final_embedding_cached_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
)
return final_embedding_cached_response, embedding_all_elements_cache_hit
return final_embedding_cached_response, embedding_all_elements_cache_hit
def _combine_cached_embedding_response_with_api_result(
self,
_caching_handler_response: CachingHandlerResponse,
embedding_response: EmbeddingResponse,
start_time: datetime.datetime,
end_time: datetime.datetime,
) -> EmbeddingResponse:
"""
Combines the cached embedding response with the API EmbeddingResponse
For caching there can be a cache hit for some of the inputs in the list and a cache miss for others
This function combines the cached embedding response with the API EmbeddingResponse
Args:
caching_handler_response: CachingHandlerResponse:
embedding_response: EmbeddingResponse:
Returns:
EmbeddingResponse:
"""
if _caching_handler_response.final_embedding_cached_response is None:
return embedding_response
idx = 0
final_data_list = []
for item in _caching_handler_response.final_embedding_cached_response.data:
if item is None and embedding_response.data is not None:
final_data_list.append(embedding_response.data[idx])
idx += 1
else:
final_data_list.append(item)
_caching_handler_response.final_embedding_cached_response.data = final_data_list
_caching_handler_response.final_embedding_cached_response._hidden_params[
"cache_hit"
] = True
_caching_handler_response.final_embedding_cached_response._response_ms = (
end_time - start_time
).total_seconds() * 1000
return _caching_handler_response.final_embedding_cached_response
def _async_log_cache_hit_on_callbacks(
self,
logging_obj: LiteLLMLoggingObj,
cached_result: Any,
start_time: datetime.datetime,
end_time: datetime.datetime,
cache_hit: bool,
):
"""
Helper function to log the success of a cached result on callbacks
Args:
logging_obj (LiteLLMLoggingObj): The logging object.
cached_result: The cached result.
start_time (datetime): The start time of the operation.
end_time (datetime): The end time of the operation.
cache_hit (bool): Whether it was a cache hit.
"""
asyncio.create_task(
logging_obj.async_success_handler(
cached_result, start_time, end_time, cache_hit
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
async def _retrieve_from_cache(
self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
) -> Optional[Any]:
"""
Internal method to
- get cache key
- check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
- async get cache value
- return the cached value
        Args:
            call_type: str: the call type, e.g. "aembedding"
            kwargs: Dict[str, Any]: the original call kwargs
            args: Tuple[Any, ...]: the original positional args
        Returns:
            Optional[Any]: the cached value, or None on a cache miss
        Raises:
            None
"""
if litellm.cache is None:
return None
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if call_type == CallTypes.aembedding.value and isinstance(
new_kwargs["input"], list
):
tasks = []
for idx, i in enumerate(new_kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
**{**new_kwargs, "input": i}
)
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
cached_result = await asyncio.gather(*tasks)
## check if cached result is None ##
if cached_result is not None and isinstance(cached_result, list):
# set cached_result to None if all elements are None
if all(result is None for result in cached_result):
cached_result = None
else:
if litellm.cache._supports_async() is True:
cached_result = await litellm.cache.async_get_cache(**new_kwargs)
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
cached_result = litellm.cache.get_cache(**new_kwargs)
return cached_result
def _convert_cached_result_to_model_response(
self,
cached_result: Any,
call_type: str,
kwargs: Dict[str, Any],
logging_obj: LiteLLMLoggingObj,
model: str,
args: Tuple[Any, ...],
custom_llm_provider: Optional[str] = None,
) -> Optional[
Union[
ModelResponse,
TextCompletionResponse,
EmbeddingResponse,
RerankResponse,
TranscriptionResponse,
CustomStreamWrapper,
]
]:
"""
        Internal method to process the cached result.
        Checks the call type and converts the cached result to the appropriate model response object,
        e.g. if the call type is text_completion, it returns a TextCompletionResponse object.
        Args:
            cached_result: Any: the raw cached value
            call_type: str: the call type, e.g. "acompletion"
            kwargs: Dict[str, Any]: the original call kwargs
            logging_obj: LiteLLMLoggingObj: logging object for the request
            model: str: model name
            custom_llm_provider: Optional[str] = None: provider name, if known
            args: Tuple[Any, ...]: the original positional args
        Returns:
            Optional[Any]: the converted response object, or None
"""
from litellm.utils import convert_to_model_response_object
if (
call_type == CallTypes.acompletion.value
or call_type == CallTypes.completion.value
) and isinstance(cached_result, dict):
if kwargs.get("stream", False) is True:
cached_result = self._convert_cached_stream_response(
cached_result=cached_result,
call_type=call_type,
logging_obj=logging_obj,
model=model,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
if (
call_type == CallTypes.atext_completion.value
or call_type == CallTypes.text_completion.value
) and isinstance(cached_result, dict):
if kwargs.get("stream", False) is True:
cached_result = self._convert_cached_stream_response(
cached_result=cached_result,
call_type=call_type,
logging_obj=logging_obj,
model=model,
)
else:
cached_result = TextCompletionResponse(**cached_result)
elif (
call_type == CallTypes.aembedding.value
or call_type == CallTypes.embedding.value
) and isinstance(cached_result, dict):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif (
call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
) and isinstance(cached_result, dict):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=None,
response_type="rerank",
)
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value
) and isinstance(cached_result, dict):
hidden_params = {
"model": "whisper-1",
"custom_llm_provider": custom_llm_provider,
"cache_hit": True,
}
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=TranscriptionResponse(),
response_type="audio_transcription",
hidden_params=hidden_params,
)
if (
hasattr(cached_result, "_hidden_params")
and cached_result._hidden_params is not None
and isinstance(cached_result._hidden_params, dict)
):
cached_result._hidden_params["cache_hit"] = True
return cached_result
def _convert_cached_stream_response(
self,
cached_result: Any,
call_type: str,
logging_obj: LiteLLMLoggingObj,
model: str,
) -> CustomStreamWrapper:
from litellm.utils import (
CustomStreamWrapper,
convert_to_streaming_response,
convert_to_streaming_response_async,
)
_stream_cached_result: Union[AsyncGenerator, Generator]
if (
call_type == CallTypes.acompletion.value
or call_type == CallTypes.atext_completion.value
):
_stream_cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
else:
_stream_cached_result = convert_to_streaming_response(
response_object=cached_result,
)
return CustomStreamWrapper(
completion_stream=_stream_cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
async def async_set_cache(
self,
result: Any,
original_function: Callable,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
):
"""
        Internal method that checks the type of the result & the cache used, and adds the result to the cache accordingly.
        Args:
            result: Any: the result to cache
            original_function: Callable: the function whose result is being cached
            kwargs: Dict[str, Any]: the original call kwargs
            args: Optional[Tuple[Any, ...]] = None: the original positional args
        Returns:
            None
        Raises:
            None
"""
if litellm.cache is None:
return
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
original_function,
args,
)
)
# [OPTIONAL] ADD TO CACHE
if self._should_store_result_in_cache(
original_function=original_function, kwargs=new_kwargs
):
if (
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.EmbeddingResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, RerankResponse)
):
if (
isinstance(result, EmbeddingResponse)
and isinstance(new_kwargs["input"], list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
asyncio.create_task(
litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
)
elif isinstance(litellm.cache.cache, S3Cache):
threading.Thread(
target=litellm.cache.add_cache,
args=(result,),
kwargs=new_kwargs,
).start()
else:
asyncio.create_task(
litellm.cache.async_add_cache(
result.model_dump_json(), **new_kwargs
)
)
else:
asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
def sync_set_cache(
self,
result: Any,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
):
"""
Sync internal method to add the result to the cache
"""
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
if litellm.cache is None:
return
if self._should_store_result_in_cache(
original_function=self.original_function, kwargs=new_kwargs
):
litellm.cache.add_cache(result, **new_kwargs)
return
def _should_store_result_in_cache(
self, original_function: Callable, kwargs: Dict[str, Any]
) -> bool:
"""
Helper function to determine if the result should be stored in the cache.
Returns:
bool: True if the result should be stored in the cache, False otherwise.
"""
return (
(litellm.cache is not None)
and litellm.cache.supported_call_types is not None
and (str(original_function.__name__) in litellm.cache.supported_call_types)
and (kwargs.get("cache", {}).get("no-store", False) is not True)
)
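    # Illustrative note (hypothetical call, not part of this class): callers can opt
    # out of cache writes per-request via the `cache` kwarg this helper inspects, e.g.
    #   litellm.completion(
    #       model="gpt-3.5-turbo",
    #       messages=[{"role": "user", "content": "hi"}],
    #       cache={"no-store": True},  # may still read from cache, but the result is not stored
    #   )
    # would make _should_store_result_in_cache return False.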
def _is_call_type_supported_by_cache(
self,
original_function: Callable,
) -> bool:
"""
Helper function to determine if the call type is supported by the cache.
        Call types are acompletion, aembedding, atext_completion, atranscription, arerank,
        as defined in `litellm.types.utils.CallTypes`.
Returns:
bool: True if the call type is supported by the cache, False otherwise.
"""
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__) in litellm.cache.supported_call_types
):
return True
return False
async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
"""
        Internal method to add the streaming response to the cache
        - If the processed chunk has a 'finish_reason', assemble a complete litellm.ModelResponse from the accumulated chunks and cache it
        - Else the chunk is appended to self.async_streaming_chunks
"""
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=processed_chunk,
start_time=self.start_time,
end_time=datetime.datetime.now(),
request_kwargs=self.request_kwargs,
streaming_chunks=self.async_streaming_chunks,
is_async=True,
)
# if a complete_streaming_response is assembled, add it to the cache
if complete_streaming_response is not None:
await self.async_set_cache(
result=complete_streaming_response,
original_function=self.original_function,
kwargs=self.request_kwargs,
)
def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
"""
Sync internal method to add the streaming response to the cache
"""
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=processed_chunk,
start_time=self.start_time,
end_time=datetime.datetime.now(),
request_kwargs=self.request_kwargs,
streaming_chunks=self.sync_streaming_chunks,
is_async=False,
)
# if a complete_streaming_response is assembled, add it to the cache
if complete_streaming_response is not None:
self.sync_set_cache(
result=complete_streaming_response,
kwargs=self.request_kwargs,
)
def _update_litellm_logging_obj_environment(
self,
logging_obj: LiteLLMLoggingObj,
model: str,
kwargs: Dict[str, Any],
cached_result: Any,
is_async: bool,
is_embedding: bool = False,
):
"""
Helper function to update the LiteLLMLoggingObj environment variables.
Args:
logging_obj (LiteLLMLoggingObj): The logging object to update.
model (str): The model being used.
kwargs (Dict[str, Any]): The keyword arguments from the original function call.
cached_result (Any): The cached result to log.
is_async (bool): Whether the call is asynchronous or not.
is_embedding (bool): Whether the call is for embeddings or not.
Returns:
None
"""
litellm_params = {
"logger_fn": kwargs.get("logger_fn", None),
"acompletion": is_async,
"api_base": kwargs.get("api_base", ""),
"metadata": kwargs.get("metadata", {}),
"model_info": kwargs.get("model_info", {}),
"proxy_server_request": kwargs.get("proxy_server_request", None),
"stream_response": kwargs.get("stream_response", {}),
}
if litellm.cache is not None:
litellm_params[
"preset_cache_key"
] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
else:
litellm_params["preset_cache_key"] = None
logging_obj.update_environment_variables(
model=model,
user=kwargs.get("user", None),
optional_params={},
litellm_params=litellm_params,
input=(
kwargs.get("messages", "")
if not is_embedding
else kwargs.get("input", "")
),
api_key=kwargs.get("api_key", None),
original_response=str(cached_result),
additional_args=None,
stream=kwargs.get("stream", False),
)
def convert_args_to_kwargs(
original_function: Callable,
args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
# Get the signature of the original function
signature = inspect.signature(original_function)
# Get parameter names in the order they appear in the original function
param_names = list(signature.parameters.keys())
# Create a mapping of positional arguments to parameter names
args_to_kwargs = {}
if args:
for index, arg in enumerate(args):
if index < len(param_names):
param_name = param_names[index]
args_to_kwargs[param_name] = arg
return args_to_kwargs
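# Illustrative sketch (hypothetical function, not part of this module): how
# convert_args_to_kwargs maps positional args onto parameter names via inspect.signature.
#
#   def embedding(model, input):
#       ...
#
#   convert_args_to_kwargs(embedding, ("text-embedding-ada-002", ["hi"]))
#   # -> {"model": "text-embedding-ada-002", "input": ["hi"]}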

View File

@@ -0,0 +1,88 @@
import json
from typing import TYPE_CHECKING, Any, Optional, Union
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
import diskcache as dc
        # if users don't provide one, use the default litellm cache directory
if disk_cache_dir is None:
self.disk_cache = dc.Cache(".litellm_cache")
else:
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
self.disk_cache.set(key, value)
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, **kwargs):
for cache_key, cache_value in cache_list:
if "ttl" in kwargs:
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
else:
self.set_cache(key=cache_key, value=cache_value)
def get_cache(self, key, **kwargs):
original_cached_response = self.disk_cache.get(key)
if original_cached_response:
try:
cached_response = json.loads(original_cached_response) # type: ignore
except Exception:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value # type: ignore
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value # type: ignore
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.disk_cache.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.disk_cache.pop(key)
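# Usage sketch (illustrative only; assumes the `diskcache` package is installed):
#
#   cache = DiskCache(disk_cache_dir="/tmp/litellm_cache")
#   cache.set_cache("greeting", "hello", ttl=60)  # entry expires after 60 seconds
#   cache.get_cache("greeting")                   # -> "hello"
#   cache.delete_cache("greeting")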

View File

@@ -0,0 +1,434 @@
"""
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
Has 4 primary methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional, Union
import litellm
from litellm._logging import print_verbose, verbose_logger
from .base_cache import BaseCache
from .in_memory_cache import InMemoryCache
from .redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
from collections import OrderedDict
class LimitedSizeOrderedDict(OrderedDict):
def __init__(self, *args, max_size=100, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = max_size
def __setitem__(self, key, value):
# If inserting a new key exceeds max size, remove the oldest item
if len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)
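# Behaviour sketch (illustrative only): with max_size=2, inserting a third key
# evicts the oldest entry first, keeping the dict bounded.
#
#   d = LimitedSizeOrderedDict(max_size=2)
#   d["a"] = 1; d["b"] = 2; d["c"] = 3
#   list(d.keys())  # -> ["b", "c"]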
class DualCache(BaseCache):
"""
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
"""
def __init__(
self,
in_memory_cache: Optional[InMemoryCache] = None,
redis_cache: Optional[RedisCache] = None,
default_in_memory_ttl: Optional[float] = None,
default_redis_ttl: Optional[float] = None,
default_redis_batch_cache_expiry: Optional[float] = None,
default_max_redis_batch_cache_size: int = 100,
) -> None:
super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache()
# If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
max_size=default_max_redis_batch_cache_size
)
self.redis_batch_cache_expiry = (
default_redis_batch_cache_expiry
or litellm.default_redis_batch_cache_expiry
or 10
)
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
)
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
def update_cache_ttl(
self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
):
if default_in_memory_ttl is not None:
self.default_in_memory_ttl = default_in_memory_ttl
if default_redis_ttl is not None:
self.default_redis_ttl = default_redis_ttl
def set_cache(self, key, value, local_only: bool = False, **kwargs):
# Update both Redis and in-memory cache
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e:
print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
raise e
def get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
result = None
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
if in_memory_result is not None:
result = in_memory_result
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
result = redis_result
print_verbose(f"get cache: cache result: {result}")
return result
except Exception:
verbose_logger.error(traceback.format_exc())
def batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
received_args = locals()
received_args.pop("self")
def run_in_new_loop():
"""Run the coroutine in a new event loop within this thread."""
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(
self.async_batch_get_cache(**received_args)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
# First, try to get the current event loop
_ = asyncio.get_running_loop()
# If we're already in an event loop, run in a separate thread
# to avoid nested event loop issues
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
return future.result()
except RuntimeError:
# No running event loop, we can safely run in this thread
return run_in_new_loop()
async def async_get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
print_verbose(
f"async get cache: cache key: {key}; local_only: {local_only}"
)
result = None
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_get_cache(
key, **kwargs
)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
await self.in_memory_cache.async_set_cache(
key, redis_result, **kwargs
)
result = redis_result
print_verbose(f"get cache: cache result: {result}")
return result
except Exception:
verbose_logger.error(traceback.format_exc())
def get_redis_batch_keys(
self,
current_time: float,
keys: List[str],
result: List[Any],
) -> List[str]:
sublist_keys = []
for key, value in zip(keys, result):
if value is None:
if (
key not in self.last_redis_batch_access_time
or current_time - self.last_redis_batch_access_time[key]
>= self.redis_batch_cache_expiry
):
sublist_keys.append(key)
return sublist_keys
async def async_batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only is False:
"""
- for the none values in the result
- check the redis cache
"""
current_time = time.time()
sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
                # Only hit Redis if the key's last batch access was more than redis_batch_cache_expiry seconds ago (default: 10s)
if len(sublist_keys) > 0:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key, value in redis_result.items():
if value is not None:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
# Update the last access time for each key fetched from Redis
self.last_redis_batch_access_time[key] = current_time
for key, value in redis_result.items():
index = keys.index(key)
result[index] = value
return result
except Exception:
verbose_logger.error(traceback.format_exc())
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache(key, value, **kwargs)
except Exception as e:
verbose_logger.exception(
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
)
# async_batch_set_cache
async def async_set_cache_pipeline(
self, cache_list: list, local_only: bool = False, **kwargs
):
"""
Batch write values to the cache
"""
print_verbose(
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache_pipeline(
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
)
except Exception as e:
verbose_logger.exception(
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
)
async def async_increment_cache(
self,
key,
value: float,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
) -> float:
"""
Key - the key in cache
Value - float - the value you want to increment by
Returns - float - the incremented value
"""
try:
result: float = value
if self.in_memory_cache is not None:
result = await self.in_memory_cache.async_increment(
key, value, **kwargs
)
if self.redis_cache is not None and local_only is False:
result = await self.redis_cache.async_increment(
key,
value,
parent_otel_span=parent_otel_span,
ttl=kwargs.get("ttl", None),
)
return result
except Exception as e:
raise e # don't log if exception is raised
async def async_set_cache_sadd(
self, key, value: List, local_only: bool = False, **kwargs
) -> None:
"""
Add value to a set
Key - the key in cache
Value - str - the value you want to add to the set
Returns - None
"""
try:
if self.in_memory_cache is not None:
_ = await self.in_memory_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None)
)
if self.redis_cache is not None and local_only is False:
_ = await self.redis_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None)
)
return None
except Exception as e:
raise e # don't log, if exception is raised
def flush_cache(self):
if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache()
if self.redis_cache is not None:
self.redis_cache.flush_cache()
def delete_cache(self, key):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
self.redis_cache.delete_cache(key)
async def async_delete_cache(self, key: str):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
await self.redis_cache.async_delete_cache(key)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache or redis
"""
ttl = await self.in_memory_cache.async_get_ttl(key)
if ttl is None and self.redis_cache is not None:
ttl = await self.redis_cache.async_get_ttl(key)
return ttl
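# Usage sketch (hypothetical constructor arguments; a reachable Redis is assumed
# for the RedisCache layer):
#
#   cache = DualCache(redis_cache=RedisCache(host="localhost", port=6379, password=None))
#   cache.set_cache("k", "v")               # written to both in-memory and Redis
#   cache.get_cache("k")                    # served from in-memory; falls back to Redis on a miss
#   cache.get_cache("k2", local_only=True)  # skips Redis entirely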

View File

@@ -0,0 +1,204 @@
"""
In-Memory Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import json
import sys
import time
from typing import Any, List, Optional
from pydantic import BaseModel
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
from .base_cache import BaseCache
class InMemoryCache(BaseCache):
def __init__(
self,
max_size_in_memory: Optional[int] = 200,
        default_ttl: Optional[
            int
        ] = 600,  # default ttl is 10 minutes; litellm's rate-limiting logic needs objects in memory for at most 1 minute
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
):
"""
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default
"""
self.max_size_in_memory = (
max_size_in_memory or 200
) # set an upper bound of 200 items in-memory
self.default_ttl = default_ttl or 600
self.max_size_per_item = (
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
) # 1MB = 1024KB
# in-memory cache
self.cache_dict: dict = {}
self.ttl_dict: dict = {}
def check_value_size(self, value: Any):
"""
Check if value size exceeds max_size_per_item (1MB)
Returns True if value size is acceptable, False otherwise
"""
try:
# Fast path for common primitive types that are typically small
if (
isinstance(value, (bool, int, float, str))
and len(str(value))
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
): # Conservative estimate
return True
# Direct size check for bytes objects
if isinstance(value, bytes):
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
# Handle special types without full conversion when possible
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
size = value.__sizeof__() / 1024
return size <= self.max_size_per_item
# Fallback for complex types
if isinstance(value, BaseModel) and hasattr(
value, "model_dump"
): # Pydantic v2
value = value.model_dump()
elif hasattr(value, "isoformat"): # datetime objects
return True # datetime strings are always small
# Only convert to JSON if absolutely necessary
if not isinstance(value, (str, bytes)):
value = json.dumps(value, default=str)
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
except Exception:
return False
def evict_cache(self):
"""
Eviction policy:
- check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict
        This guarantees the following:
        - 1. When an item's ttl is not set: the item remains in memory for at least the default TTL (10 minutes)
        - 2. When a ttl is set: the item remains in memory for at least that amount of time
        - 3. The size of the in-memory cache stays bounded
"""
for key in list(self.ttl_dict.keys()):
if time.time() > self.ttl_dict[key]:
self.cache_dict.pop(key, None)
self.ttl_dict.pop(key, None)
# de-reference the removed item
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
# This can occur when an object is referenced by another object, but the reference is never removed.
def set_cache(self, key, value, **kwargs):
if len(self.cache_dict) >= self.max_size_in_memory:
# only evict when cache is full
self.evict_cache()
if not self.check_value_size(value):
return
self.cache_dict[key] = value
if "ttl" in kwargs and kwargs["ttl"] is not None:
self.ttl_dict[key] = time.time() + kwargs["ttl"]
else:
self.ttl_dict[key] = time.time() + self.default_ttl
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
for cache_key, cache_value in cache_list:
if ttl is not None:
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
else:
self.set_cache(key=cache_key, value=cache_value)
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
"""
Add value to set
"""
# get the value
init_value = self.get_cache(key=key) or set()
for val in value:
init_value.add(val)
self.set_cache(key, init_value, ttl=ttl)
return value
def get_cache(self, key, **kwargs):
if key in self.cache_dict:
if key in self.ttl_dict:
if time.time() > self.ttl_dict[key]:
self.cache_dict.pop(key, None)
return None
original_cached_response = self.cache_dict[key]
try:
cached_response = json.loads(original_cached_response)
except Exception:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: float, **kwargs) -> float:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.cache_dict.pop(key, None)
self.ttl_dict.pop(key, None)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache
"""
return self.ttl_dict.get(key, None)
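# Usage sketch (illustrative only):
#
#   import time
#   cache = InMemoryCache(max_size_in_memory=100)
#   cache.set_cache("k", {"v": 1}, ttl=1)
#   cache.get_cache("k")   # -> {"v": 1}
#   time.sleep(1.1)
#   cache.get_cache("k")   # -> None, the entry expired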

View File

@@ -0,0 +1,39 @@
"""
Add the event loop to the cache key, to prevent event loop closed errors.
"""
import asyncio
from .in_memory_cache import InMemoryCache
class LLMClientCache(InMemoryCache):
def update_cache_key_with_event_loop(self, key):
"""
Add the event loop to the cache key, to prevent event loop closed errors.
If none, use the key as is.
"""
try:
event_loop = asyncio.get_event_loop()
stringified_event_loop = str(id(event_loop))
return f"{key}-{stringified_event_loop}"
except Exception: # handle no current event loop
return key
def set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().set_cache(key, value, **kwargs)
async def async_set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_set_cache(key, value, **kwargs)
def get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().get_cache(key, **kwargs)
async def async_get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_get_cache(key, **kwargs)
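# Behaviour sketch (illustrative only): the same logical key typically resolves to
# different physical keys under different event loops, so a client cached against a
# closed loop is not handed back to a new loop.
#
#   import asyncio
#   cache = LLMClientCache()
#
#   async def main():
#       await cache.async_set_cache("client", object())
#       return await cache.async_get_cache("client")  # hit: same loop, same key suffix
#
#   asyncio.run(main())
#   asyncio.run(main())  # new loop -> new key suffix -> miss, so the caller builds a fresh client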

View File

@@ -0,0 +1,442 @@
"""
Qdrant Semantic Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import ast
import asyncio
import json
from typing import Any, cast
import litellm
from litellm._logging import print_verbose
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class QdrantSemanticCache(BaseCache):
def __init__( # noqa: PLR0915
self,
qdrant_api_base=None,
qdrant_api_key=None,
collection_name=None,
similarity_threshold=None,
quantization_config=None,
embedding_model="text-embedding-ada-002",
host_type=None,
):
import os
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.secret_managers.main import get_secret_str
if collection_name is None:
raise Exception("collection_name must be provided, passed None")
self.collection_name = collection_name
print_verbose(
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
)
if similarity_threshold is None:
raise Exception("similarity_threshold must be provided, passed None")
self.similarity_threshold = similarity_threshold
self.embedding_model = embedding_model
headers = {}
# check if defined as os.environ/ variable
if qdrant_api_base:
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
"os.environ/"
):
qdrant_api_base = get_secret_str(qdrant_api_base)
if qdrant_api_key:
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
"os.environ/"
):
qdrant_api_key = get_secret_str(qdrant_api_key)
qdrant_api_base = (
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
)
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
headers = {"Content-Type": "application/json"}
if qdrant_api_key:
headers["api-key"] = qdrant_api_key
if qdrant_api_base is None:
raise ValueError("Qdrant url must be provided")
self.qdrant_api_base = qdrant_api_base
self.qdrant_api_key = qdrant_api_key
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
self.headers = headers
self.sync_client = _get_httpx_client()
self.async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.Caching
)
if quantization_config is None:
print_verbose(
"Quantization config is not provided. Default binary quantization will be used."
)
collection_exists = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
headers=self.headers,
)
if collection_exists.status_code != 200:
raise ValueError(
f"Error from qdrant checking if /collections exist {collection_exists.text}"
)
if collection_exists.json()["result"]["exists"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(
f"Collection already exists.\nCollection details:{self.collection_info}"
)
else:
if quantization_config is None or quantization_config == "binary":
quantization_params = {
"binary": {
"always_ram": False,
}
}
elif quantization_config == "scalar":
quantization_params = {
"scalar": {
"type": "int8",
"quantile": QDRANT_SCALAR_QUANTILE,
"always_ram": False,
}
}
elif quantization_config == "product":
quantization_params = {
"product": {"compression": "x16", "always_ram": False}
}
else:
raise Exception(
"Quantization config must be one of 'scalar', 'binary' or 'product'"
)
new_collection_status = self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
json={
"vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
"quantization_config": quantization_params,
},
headers=self.headers,
)
if new_collection_status.json()["result"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(
f"New collection created.\nCollection details:{self.collection_info}"
)
else:
raise Exception("Error while creating new collection")
def _get_cache_logic(self, cached_response: Any):
if cached_response is None:
return cached_response
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
def set_cache(self, key, value, **kwargs):
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
import uuid
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
value = str(value)
assert isinstance(value, str)
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
},
]
}
self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
return
def get_cache(self, key, **kwargs):
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# convert to embedding
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit": 1,
"with_payload": True,
}
search_response = self.sync_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
results = search_response.json()["result"]
if results is None:
return None
if isinstance(results, list):
if len(results) == 0:
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
async def async_set_cache(self, key, value, **kwargs):
import uuid
from litellm.proxy.proxy_server import llm_model_list, llm_router
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
value = str(value)
assert isinstance(value, str)
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
},
]
}
await self.async_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
return
async def async_get_cache(self, key, **kwargs):
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
from litellm.proxy.proxy_server import llm_model_list, llm_router
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit": 1,
"with_payload": True,
}
search_response = await self.async_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
results = search_response.json()["result"]
if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
if len(results) == 0:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
async def _collection_info(self):
return self.collection_info
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
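# Usage sketch (hypothetical values; assumes a reachable Qdrant instance and
# credentials for the embedding model):
#
#   qdrant_cache = QdrantSemanticCache(
#       qdrant_api_base="http://localhost:6333",
#       collection_name="litellm_semantic_cache",
#       similarity_threshold=0.8,
#   )
#   qdrant_cache.set_cache(key="unused", value='{"answer": "..."}',
#                          messages=[{"role": "user", "content": "What is LiteLLM?"}])
#   # A semantically similar prompt can now return the cached value:
#   qdrant_cache.get_cache(key="unused",
#                          messages=[{"role": "user", "content": "Tell me what LiteLLM is"}])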

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,59 @@
"""
Redis Cluster Cache implementation
Key differences:
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union
from litellm.caching.redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline
pipeline = Pipeline
async_redis_client = Redis
Span = Union[_Span, Any]
else:
pipeline = Any
async_redis_client = Any
Span = Any
class RedisClusterCache(RedisCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
def init_async_client(self):
from redis.asyncio import RedisCluster
from .._redis import get_redis_async_client
if self.redis_async_redis_cluster_client:
return self.redis_async_redis_cluster_client
_redis_client = get_redis_async_client(
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
)
if isinstance(_redis_client, RedisCluster):
self.redis_async_redis_cluster_client = _redis_client
return _redis_client
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Overrides `_run_redis_mget_operation` in redis_cache.py
"""
return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Overrides `_async_run_redis_mget_operation` in redis_cache.py
"""
async_redis_cluster_client = self.init_async_client()
return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore
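# Note (illustrative; hypothetical constructor arguments): the async cluster client
# is created once and memoized on the instance, since re-creating it per request
# adds significant latency.
#
#   cluster_cache = RedisClusterCache(startup_nodes=[{"host": "127.0.0.1", "port": "7000"}])
#   client = cluster_cache.init_async_client()          # created on the first call
#   assert cluster_cache.init_async_client() is client  # reused afterwards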

View File

@@ -0,0 +1,450 @@
"""
Redis Semantic Cache implementation for LiteLLM
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
This cache stores responses based on the semantic similarity of prompts rather than
exact matching, allowing for more flexible caching of LLM responses.
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
and their cached responses.
"""
import ast
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Tuple, cast
import litellm
from litellm._logging import print_verbose
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class RedisSemanticCache(BaseCache):
"""
Redis-backed semantic cache for LLM responses.
This cache uses vector similarity to find semantically similar prompts that have been
previously sent to the LLM, allowing for cache hits even when prompts are not identical
but carry similar meaning.
"""
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
def __init__(
self,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
redis_url: Optional[str] = None,
similarity_threshold: Optional[float] = None,
embedding_model: str = "text-embedding-ada-002",
index_name: Optional[str] = None,
**kwargs,
):
"""
Initialize the Redis Semantic Cache.
Args:
host: Redis host address
port: Redis port
password: Redis password
redis_url: Full Redis URL (alternative to separate host/port/password)
similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
where 1.0 requires exact matches and 0.0 accepts any match
embedding_model: Model to use for generating embeddings
index_name: Name for the Redis index
ttl: Default time-to-live for cache entries in seconds
**kwargs: Additional arguments passed to the Redis client
Raises:
Exception: If similarity_threshold is not provided or required Redis
connection information is missing
"""
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer
if index_name is None:
index_name = self.DEFAULT_REDIS_INDEX_NAME
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
# Validate similarity threshold
if similarity_threshold is None:
raise ValueError("similarity_threshold must be provided, passed None")
# Store configuration
self.similarity_threshold = similarity_threshold
# Convert similarity threshold [0,1] to distance threshold [0,2]
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
self.distance_threshold = 1 - similarity_threshold
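        # Worked example (illustrative): similarity_threshold=0.8 gives
        # distance_threshold = 1 - 0.8 = 0.2, so a cached prompt whose cosine
        # distance from the new prompt is <= 0.2 (i.e. similarity >= 0.8) is a hit.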
self.embedding_model = embedding_model
# Set up Redis connection
if redis_url is None:
try:
# Attempt to use provided parameters or fallback to environment variables
host = host or os.environ["REDIS_HOST"]
port = port or os.environ["REDIS_PORT"]
password = password or os.environ["REDIS_PASSWORD"]
except KeyError as e:
# Raise a more informative exception if any of the required keys are missing
missing_var = e.args[0]
raise ValueError(
f"Missing required Redis configuration: {missing_var}. "
f"Provide {missing_var} or redis_url."
) from e
redis_url = f"redis://:{password}@{host}:{port}"
print_verbose(f"Redis semantic-cache redis_url: {redis_url}")
# Initialize the Redis vectorizer and cache
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
self.llmcache = SemanticCache(
name=index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
overwrite=False,
)
def _get_ttl(self, **kwargs) -> Optional[int]:
"""
Get the TTL (time-to-live) value for cache entries.
Args:
**kwargs: Keyword arguments that may contain a custom TTL
Returns:
Optional[int]: The TTL value in seconds, or None if no TTL should be applied
"""
ttl = kwargs.get("ttl")
if ttl is not None:
ttl = int(ttl)
return ttl
def _get_embedding(self, prompt: str) -> List[float]:
"""
Generate an embedding vector for the given prompt using the configured embedding model.
Args:
prompt: The text to generate an embedding for
Returns:
List[float]: The embedding vector
"""
# Create an embedding from prompt
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
embedding = embedding_response["data"][0]["embedding"]
return embedding
def _get_cache_logic(self, cached_response: Any) -> Any:
"""
Process the cached response to prepare it for use.
Args:
cached_response: The raw cached response
Returns:
The processed cache response, or None if input was None
"""
if cached_response is None:
return cached_response
# Convert bytes to string if needed
if isinstance(cached_response, bytes):
cached_response = cached_response.decode("utf-8")
# Convert string representation to Python object
try:
cached_response = json.loads(cached_response)
except json.JSONDecodeError:
try:
cached_response = ast.literal_eval(cached_response)
except (ValueError, SyntaxError) as e:
print_verbose(f"Error parsing cached response: {str(e)}")
return None
return cached_response
def set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
"""
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
value_str: Optional[str] = None
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
self.llmcache.store(prompt, value_str, ttl=int(ttl))
else:
self.llmcache.store(prompt, value_str)
except Exception as e:
print_verbose(
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
)
def get_cache(self, key: str, **kwargs) -> Any:
"""
Retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic cache lookup")
return None
prompt = get_str_from_messages(messages)
# Check the cache for semantically similar prompts
results = self.llmcache.check(prompt=prompt)
# Return None if no similar prompts found
if not results:
return None
# Process the best matching result
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity score
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
similarity = 1 - vector_distance
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
print_verbose(
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
f"actual similarity: {similarity}, "
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
"""
Asynchronously generate an embedding for the given prompt.
Args:
prompt: The text to generate an embedding for
**kwargs: Additional arguments that may contain metadata
Returns:
List[float]: The embedding vector
"""
from litellm.proxy.proxy_server import llm_model_list, llm_router
# Route the embedding request through the proxy if appropriate
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
try:
if llm_router is not None and self.embedding_model in router_model_names:
# Use the router for embedding generation
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# Generate embedding directly
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# Extract and return the embedding vector
return embedding_response["data"][0]["embedding"]
except Exception as e:
print_verbose(f"Error generating async embedding: {str(e)}")
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Asynchronously store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
"""
print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
            # Generate an embedding for the prompt (passed through as a custom vector below)
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
ttl=ttl,
)
else:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
)
except Exception as e:
print_verbose(f"Error in async_set_cache: {str(e)}")
async def async_get_cache(self, key: str, **kwargs) -> Any:
"""
Asynchronously retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic cache lookup")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
prompt = get_str_from_messages(messages)
# Generate embedding for the prompt
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Check the cache for semantically similar prompts
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
# handle results / cache hit
            if not results:
                # No semantically similar prompt found; record zero similarity
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
similarity = 1 - vector_distance
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
print_verbose(
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
f"actual similarity: {similarity}, "
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
        except Exception as e:
            print_verbose(f"Error in async_get_cache: {str(e)}")
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None
async def _index_info(self) -> Dict[str, Any]:
"""
Get information about the Redis index.
Returns:
Dict[str, Any]: Information about the Redis index
"""
aindex = await self.llmcache._get_async_index()
return await aindex.info()
async def async_set_cache_pipeline(
self, cache_list: List[Tuple[str, Any]], **kwargs
) -> None:
"""
Asynchronously store multiple values in the semantic cache.
Args:
cache_list: List of (key, value) tuples to cache
**kwargs: Additional arguments
"""
try:
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
except Exception as e:
print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")

View File

@@ -0,0 +1,159 @@
"""
S3 Cache implementation
WARNING: DO NOT USE THIS IN PRODUCTION - This is not ASYNC
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import ast
import asyncio
import json
from typing import Optional
from litellm._logging import print_verbose, verbose_logger
from .base_cache import BaseCache
class S3Cache(BaseCache):
def __init__(
self,
s3_bucket_name,
s3_region_name=None,
s3_api_version=None,
s3_use_ssl: Optional[bool] = True,
s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
s3_aws_secret_access_key=None,
s3_aws_session_token=None,
s3_config=None,
s3_path=None,
**kwargs,
):
import boto3
self.bucket_name = s3_bucket_name
self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
# Create an S3 client with custom endpoint URL
self.s3_client = boto3.client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
api_version=s3_api_version,
use_ssl=s3_use_ssl,
verify=s3_verify,
aws_access_key_id=s3_aws_access_key_id,
aws_secret_access_key=s3_aws_secret_access_key,
aws_session_token=s3_aws_session_token,
config=s3_config,
**kwargs,
)
def set_cache(self, key, value, **kwargs):
try:
print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
ttl = kwargs.get("ttl", None)
# Convert value to JSON before storing in S3
serialized_value = json.dumps(value)
key = self.key_prefix + key
if ttl is not None:
cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
                import datetime

                # Calculate the expiration time (ttl is a number of seconds)
                expiration_time = datetime.datetime.now() + datetime.timedelta(
                    seconds=ttl
                )
# Upload the data to S3 with the calculated expiration time
self.s3_client.put_object(
Bucket=self.bucket_name,
Key=key,
Body=serialized_value,
Expires=expiration_time,
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{key}.json"',
)
else:
cache_control = "immutable, max-age=31536000, s-maxage=31536000"
# Upload the data to S3 without specifying Expires
self.s3_client.put_object(
Bucket=self.bucket_name,
Key=key,
Body=serialized_value,
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{key}.json"',
)
except Exception as e:
# NON blocking - notify users S3 is throwing an exception
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
def get_cache(self, key, **kwargs):
import botocore
try:
key = self.key_prefix + key
print_verbose(f"Get S3 Cache: key: {key}")
# Download the data from S3
cached_response = self.s3_client.get_object(
Bucket=self.bucket_name, Key=key
)
if cached_response is not None:
                # cached_response["Body"] is a bytes stream; decode it, then parse
cached_response = (
cached_response["Body"].read().decode("utf-8")
) # Convert bytes to string
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
if not isinstance(cached_response, dict):
cached_response = dict(cached_response)
verbose_logger.debug(
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
)
return cached_response
except botocore.exceptions.ClientError as e: # type: ignore
if e.response["Error"]["Code"] == "NoSuchKey":
verbose_logger.debug(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
)
return None
except Exception as e:
# NON blocking - notify users S3 is throwing an exception
verbose_logger.error(
f"S3 Caching: get_cache() - Got exception from S3: {e}"
)
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
def flush_cache(self):
pass
async def disconnect(self):
pass
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
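# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition): round-tripping a value through
# S3Cache, using the constructor defined above. The bucket name and region
# are placeholders; boto3 resolves credentials from the environment as usual.
# ---------------------------------------------------------------------------
def _s3_cache_demo() -> None:
    cache = S3Cache(s3_bucket_name="my-litellm-cache", s3_region_name="us-east-1")
    cache.set_cache("example-key", {"answer": 42})  # stored as JSON
    print(cache.get_cache("example-key"))  # -> {'answer': 42} on success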

View File

@@ -0,0 +1,539 @@
from typing import List, Literal
ROUTER_MAX_FALLBACKS = 5
DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
DEFAULT_MAX_RETRIES = 2
DEFAULT_MAX_RECURSE_DEPTH = 10
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
    0.5  # by default, cool down a deployment if 50% of its requests fail in a given minute
)
DEFAULT_MAX_TOKENS = 4096
DEFAULT_ALLOWED_FAILS = 3
DEFAULT_REDIS_SYNC_INTERVAL = 1
DEFAULT_COOLDOWN_TIME_SECONDS = 5
DEFAULT_REPLICATE_POLLING_RETRIES = 5
DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
DEFAULT_IMAGE_TOKEN_COUNT = 250
DEFAULT_IMAGE_WIDTH = 300
DEFAULT_IMAGE_HEIGHT = 300
DEFAULT_MAX_TOKENS = 256  # used when providers need a default; NOTE: this overrides the 4096 value assigned above
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET = 1024
DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET = 2048
DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET = 4096
########## Networking constants ##############################################################
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
########### v2 Architecture constants for managing writing updates to the database ###########
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer"
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer"
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
MAX_SIZE_IN_MEMORY_QUEUE = 10000
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000
###############################################################################################
MINIMUM_PROMPT_CACHE_TOKEN_COUNT = (
1024 # minimum number of tokens to cache a prompt by Anthropic
)
DEFAULT_TRIM_RATIO = 0.75 # default ratio of tokens to trim from the end of a prompt
HOURS_IN_A_DAY = 24
DAYS_IN_A_WEEK = 7
DAYS_IN_A_MONTH = 28
DAYS_IN_A_YEAR = 365
REPLICATE_MODEL_NAME_WITH_ID_LENGTH = 64
#### TOKEN COUNTING ####
FUNCTION_DEFINITION_TOKEN_COUNT = 9
SYSTEM_MESSAGE_TOKEN_COUNT = 4
TOOL_CHOICE_OBJECT_TOKEN_COUNT = 4
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT = 10
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT = 20
MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES = 768
MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES = 2000
MAX_TILE_WIDTH = 512
MAX_TILE_HEIGHT = 512
OPENAI_FILE_SEARCH_COST_PER_1K_CALLS = 2.5 / 1000
MIN_NON_ZERO_TEMPERATURE = 0.0001
#### RELIABILITY ####
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
DEFAULT_MAX_LRU_CACHE_SIZE = 16
INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
JITTER = 0.75
DEFAULT_IN_MEMORY_TTL = 5 # default time to live for the in-memory cache
DEFAULT_POLLING_INTERVAL = 0.03 # default polling interval for the scheduler
AZURE_OPERATION_POLLING_TIMEOUT = 120
REDIS_SOCKET_TIMEOUT = 0.1
REDIS_CONNECTION_POOL_TIMEOUT = 5
NON_LLM_CONNECTION_TIMEOUT = 15 # timeout for adjacent services (e.g. jwt auth)
MAX_EXCEPTION_MESSAGE_LENGTH = 2000
BEDROCK_MAX_POLICY_SIZE = 75
REPLICATE_POLLING_DELAY_SECONDS = 0.5
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS = 4096
TOGETHER_AI_4_B = 4
TOGETHER_AI_8_B = 8
TOGETHER_AI_21_B = 21
TOGETHER_AI_41_B = 41
TOGETHER_AI_80_B = 80
TOGETHER_AI_110_B = 110
TOGETHER_AI_EMBEDDING_150_M = 150
TOGETHER_AI_EMBEDDING_350_M = 350
QDRANT_SCALAR_QUANTILE = 0.99
QDRANT_VECTOR_SIZE = 1536
CACHED_STREAMING_CHUNK_DELAY = 0.02
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 512  # NOTE: this overrides the 1024 value assigned above
DEFAULT_MAX_TOKENS_FOR_TRITON = 2000
#### Networking settings ####
request_timeout: float = 6000 # time in seconds
STREAM_SSE_DONE_STRING: str = "[DONE]"
### SPEND TRACKING ###
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND = 0.001400 # price per second for a100 80GB
FIREWORKS_AI_56_B_MOE = 56
FIREWORKS_AI_176_B_MOE = 176
FIREWORKS_AI_16_B = 16
FIREWORKS_AI_80_B = 80
LITELLM_CHAT_PROVIDERS = [
"openai",
"openai_like",
"xai",
"custom_openai",
"text-completion-openai",
"cohere",
"cohere_chat",
"clarifai",
"anthropic",
"anthropic_text",
"replicate",
"huggingface",
"together_ai",
"openrouter",
"vertex_ai",
"vertex_ai_beta",
"gemini",
"ai21",
"baseten",
"azure",
"azure_text",
"azure_ai",
"sagemaker",
"sagemaker_chat",
"bedrock",
"vllm",
"nlp_cloud",
"petals",
"oobabooga",
"ollama",
"ollama_chat",
"deepinfra",
"perplexity",
"mistral",
"groq",
"nvidia_nim",
"cerebras",
"ai21_chat",
"volcengine",
"codestral",
"text-completion-codestral",
"deepseek",
"sambanova",
"maritalk",
"cloudflare",
"fireworks_ai",
"friendliai",
"watsonx",
"watsonx_text",
"triton",
"predibase",
"databricks",
"empower",
"github",
"custom",
"litellm_proxy",
"hosted_vllm",
"lm_studio",
"galadriel",
]
OPENAI_CHAT_COMPLETION_PARAMS = [
"functions",
"function_call",
"temperature",
"temperature",
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_completion_tokens",
"modalities",
"prediction",
"audio",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"request_timeout",
"api_base",
"api_version",
"api_key",
"deployment_id",
"organization",
"base_url",
"default_headers",
"timeout",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
"parallel_tool_calls",
"logprobs",
"top_logprobs",
"reasoning_effort",
"extra_headers",
"thinking",
]
openai_compatible_endpoints: List = [
"api.perplexity.ai",
"api.endpoints.anyscale.com/v1",
"api.deepinfra.com/v1/openai",
"api.mistral.ai/v1",
"codestral.mistral.ai/v1/chat/completions",
"codestral.mistral.ai/v1/fim/completions",
"api.groq.com/openai/v1",
"https://integrate.api.nvidia.com/v1",
"api.deepseek.com/v1",
"api.together.xyz/v1",
"app.empower.dev/api/v1",
"https://api.friendli.ai/serverless/v1",
"api.sambanova.ai/v1",
"api.x.ai/v1",
"api.galadriel.ai/v1",
]
openai_compatible_providers: List = [
"anyscale",
"mistral",
"groq",
"nvidia_nim",
"cerebras",
"sambanova",
"ai21_chat",
"ai21",
"volcengine",
"codestral",
"deepseek",
"deepinfra",
"perplexity",
"xinference",
"xai",
"together_ai",
"fireworks_ai",
"empower",
"friendliai",
"azure_ai",
"github",
"litellm_proxy",
"hosted_vllm",
"lm_studio",
"galadriel",
]
openai_text_completion_compatible_providers: List = (
[ # providers that support `/v1/completions`
"together_ai",
"fireworks_ai",
"hosted_vllm",
]
)
_openai_like_providers: List = [
"predibase",
"databricks",
"watsonx",
] # private helper. similar to openai but require some custom auth / endpoint handling, so can't use the openai sdk
# well supported replicate llms
replicate_models: List = [
# llama replicate supported LLMs
"replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
"a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
"meta/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db",
# Vicuna
"replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
"joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
# Flan T-5
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
# Others
"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
]
clarifai_models: List = [
"clarifai/meta.Llama-3.Llama-3-8B-Instruct",
"clarifai/gcp.generate.gemma-1_1-7b-it",
"clarifai/mistralai.completion.mixtral-8x22B",
"clarifai/cohere.generate.command-r-plus",
"clarifai/databricks.drbx.dbrx-instruct",
"clarifai/mistralai.completion.mistral-large",
"clarifai/mistralai.completion.mistral-medium",
"clarifai/mistralai.completion.mistral-small",
"clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
"clarifai/gcp.generate.gemma-2b-it",
"clarifai/gcp.generate.gemma-7b-it",
"clarifai/deci.decilm.deciLM-7B-instruct",
"clarifai/mistralai.completion.mistral-7B-Instruct",
"clarifai/gcp.generate.gemini-pro",
"clarifai/anthropic.completion.claude-v1",
"clarifai/anthropic.completion.claude-instant-1_2",
"clarifai/anthropic.completion.claude-instant",
"clarifai/anthropic.completion.claude-v2",
"clarifai/anthropic.completion.claude-2_1",
"clarifai/meta.Llama-2.codeLlama-70b-Python",
"clarifai/meta.Llama-2.codeLlama-70b-Instruct",
"clarifai/openai.completion.gpt-3_5-turbo-instruct",
"clarifai/meta.Llama-2.llama2-7b-chat",
"clarifai/meta.Llama-2.llama2-13b-chat",
"clarifai/meta.Llama-2.llama2-70b-chat",
"clarifai/openai.chat-completion.gpt-4-turbo",
"clarifai/microsoft.text-generation.phi-2",
"clarifai/meta.Llama-2.llama2-7b-chat-vllm",
"clarifai/upstage.solar.solar-10_7b-instruct",
"clarifai/openchat.openchat.openchat-3_5-1210",
"clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
"clarifai/gcp.generate.text-bison",
"clarifai/meta.Llama-2.llamaGuard-7b",
"clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
"clarifai/openai.chat-completion.GPT-4",
"clarifai/openai.chat-completion.GPT-3_5-turbo",
"clarifai/ai21.complete.Jurassic2-Grande",
"clarifai/ai21.complete.Jurassic2-Grande-Instruct",
"clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
"clarifai/ai21.complete.Jurassic2-Jumbo",
"clarifai/ai21.complete.Jurassic2-Large",
"clarifai/cohere.generate.cohere-generate-command",
"clarifai/wizardlm.generate.wizardCoder-Python-34B",
"clarifai/wizardlm.generate.wizardLM-70B",
"clarifai/tiiuae.falcon.falcon-40b-instruct",
"clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
"clarifai/gcp.generate.code-gecko",
"clarifai/gcp.generate.code-bison",
"clarifai/mistralai.completion.mistral-7B-OpenOrca",
"clarifai/mistralai.completion.openHermes-2-mistral-7B",
"clarifai/wizardlm.generate.wizardLM-13B",
"clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
"clarifai/wizardlm.generate.wizardCoder-15B",
"clarifai/microsoft.text-generation.phi-1_5",
"clarifai/databricks.Dolly-v2.dolly-v2-12b",
"clarifai/bigcode.code.StarCoder",
"clarifai/salesforce.xgen.xgen-7b-8k-instruct",
"clarifai/mosaicml.mpt.mpt-7b-instruct",
"clarifai/anthropic.completion.claude-3-opus",
"clarifai/anthropic.completion.claude-3-sonnet",
"clarifai/gcp.generate.gemini-1_5-pro",
"clarifai/gcp.generate.imagen-2",
"clarifai/salesforce.blip.general-english-image-caption-blip-2",
]
huggingface_models: List = [
"meta-llama/Llama-2-7b-hf",
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-13b-hf",
"meta-llama/Llama-2-13b-chat-hf",
"meta-llama/Llama-2-70b-hf",
"meta-llama/Llama-2-70b-chat-hf",
"meta-llama/Llama-2-7b",
"meta-llama/Llama-2-7b-chat",
"meta-llama/Llama-2-13b",
"meta-llama/Llama-2-13b-chat",
"meta-llama/Llama-2-70b",
"meta-llama/Llama-2-70b-chat",
]  # these have been tested extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/providers
empower_models = [
"empower/empower-functions",
"empower/empower-functions-small",
]
together_ai_models: List = [
# llama llms - chat
"togethercomputer/llama-2-70b-chat",
# llama llms - language / instruct
"togethercomputer/llama-2-70b",
"togethercomputer/LLaMA-2-7B-32K",
"togethercomputer/Llama-2-7B-32K-Instruct",
"togethercomputer/llama-2-7b",
# falcon llms
"togethercomputer/falcon-40b-instruct",
"togethercomputer/falcon-7b-instruct",
# alpaca
"togethercomputer/alpaca-7b",
# chat llms
"HuggingFaceH4/starchat-alpha",
# code llms
"togethercomputer/CodeLlama-34b",
"togethercomputer/CodeLlama-34b-Instruct",
"togethercomputer/CodeLlama-34b-Python",
"defog/sqlcoder",
"NumbersStation/nsql-llama-2-7B",
"WizardLM/WizardCoder-15B-V1.0",
"WizardLM/WizardCoder-Python-34B-V1.0",
# language llms
"NousResearch/Nous-Hermes-Llama2-13b",
"Austism/chronos-hermes-13b",
"upstage/SOLAR-0-70b-16bit",
"WizardLM/WizardLM-70B-V1.0",
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
baseten_models: List = [
"qvv0xeq",
"q841o8w",
"31dxrj3",
] # FALCON 7B # WizardLM # Mosaic ML
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
"cohere",
"anthropic",
"mistral",
"amazon",
"meta",
"llama",
"ai21",
"nova",
"deepseek_r1",
]
open_ai_embedding_models: List = ["text-embedding-ada-002"]
cohere_embedding_models: List = [
"embed-english-v3.0",
"embed-english-light-v3.0",
"embed-multilingual-v3.0",
"embed-english-v2.0",
"embed-english-light-v2.0",
"embed-multilingual-v2.0",
]
bedrock_embedding_models: List = [
"amazon.titan-embed-text-v1",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]
known_tokenizer_config = {
"mistralai/Mistral-7B-Instruct-v0.1": {
"tokenizer": {
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"bos_token": "<s>",
"eos_token": "</s>",
},
"status": "success",
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"tokenizer": {
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
"bos_token": "<|begin_of_text|>",
"eos_token": "",
},
"status": "success",
},
"deepseek-r1/deepseek-r1-7b-instruct": {
"tokenizer": {
"add_bos_token": True,
"add_eos_token": False,
"bos_token": {
"__type": "AddedToken",
"content": "<begin▁of▁sentence>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False,
},
"clean_up_tokenization_spaces": False,
"eos_token": {
"__type": "AddedToken",
"content": "<end▁of▁sentence>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False,
},
"legacy": True,
"model_max_length": 16384,
"pad_token": {
"__type": "AddedToken",
"content": "<end▁of▁sentence>",
"lstrip": False,
"normalized": True,
"rstrip": False,
"single_word": False,
},
"sp_model_kwargs": {},
"unk_token": None,
"tokenizer_class": "LlamaTokenizerFast",
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<tool▁call▁end>'}}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<Assistant>' + content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<tool▁outputs▁end>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<Assistant><think>\\n'}}{% endif %}",
},
"status": "success",
},
}
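# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition): rendering one of the chat templates
# above with Jinja2, the engine HF tokenizers use for `chat_template`. The
# jinja2 dependency and the direct `Template(...).render(...)` call are
# assumptions; HF applies these templates through a sandboxed environment.
# ---------------------------------------------------------------------------
def _render_chat_template_demo() -> str:
    import jinja2

    cfg = known_tokenizer_config["mistralai/Mistral-7B-Instruct-v0.1"]["tokenizer"]
    template = jinja2.Template(cfg["chat_template"])
    # With a single user message this renders "<s>[INST] Hello [/INST]"
    return template.render(
        messages=[{"role": "user", "content": "Hello"}],
        bos_token=cfg["bos_token"],
        eos_token=cfg["eos_token"],
    )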
OPENAI_FINISH_REASONS = ["stop", "length", "function_call", "content_filter", "null"]
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = 60 # 1 minute
RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when converting response format to tool call
########################### Logging Callback Constants ###########################
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
MCP_TOOL_NAME_PREFIX = "mcp_tool"
########################### LiteLLM Proxy Specific Constants ###########################
########################################################################################
MAX_SPENDLOG_ROWS_TO_QUERY = (
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
)
DEFAULT_SOFT_BUDGET = (
50.0 # by default all litellm proxy keys have a soft budget of 50.0
)
# makes it clear this is a rate limit error for a litellm virtual key
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
# pass through route constants
BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
"agents/",
"knowledgebases/",
"flows/",
"retrieveAndGenerate/",
"rerank/",
"generateQuery/",
"optimize-prompt/",
]
BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600 # 1 hour
BATCH_STATUS_POLL_MAX_ATTEMPTS = 24 # for 24 hours
HEALTH_CHECK_TIMEOUT_SECONDS = 60 # 60 seconds
UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard"
LITELLM_PROXY_ADMIN_NAME = "default_user_id"
########################### DB CRON JOB NAMES ###########################
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
PROXY_BATCH_WRITE_AT = 10 # in seconds
DEFAULT_HEALTH_CHECK_INTERVAL = 300 # 5 minutes
PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = 9
DEFAULT_MODEL_CREATED_AT_TIME = 1677610602 # returns on `/models` endpoint
DEFAULT_SLACK_ALERTING_THRESHOLD = 300
MAX_TEAM_LIST_LIMIT = 20
DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = 0.7
LENGTH_OF_LITELLM_GENERATED_KEY = 16
SECRET_MANAGER_REFRESH_INTERVAL = 86400
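# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition): one plausible way the retry constants
# above (INITIAL_RETRY_DELAY, MAX_RETRY_DELAY, JITTER) combine into jittered
# exponential backoff. The exact formula used by the router lives elsewhere
# in the codebase, so treat this as an assumption.
# ---------------------------------------------------------------------------
def _backoff_delay(attempt: int) -> float:
    import random

    base = min(INITIAL_RETRY_DELAY * (2**attempt), MAX_RETRY_DELAY)
    return base * random.uniform(JITTER, 1.0)  # jitter factor is an assumption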

View File

@@ -0,0 +1,5 @@
{
"gpt-3.5-turbo-0613": 0.00015000000000000001,
"claude-2": 0.00016454,
"gpt-4-0613": 0.015408
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,808 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
## LiteLLM versions of the OpenAI Exception Types
from typing import Optional
import httpx
import openai
from litellm.types.utils import LiteLLMCommonStrings
class AuthenticationError(openai.AuthenticationError): # type: ignore
def __init__(
self,
message,
llm_provider,
model,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 401
self.message = "litellm.AuthenticationError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__(
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
# raise when invalid models passed, example gpt-8
class NotFoundError(openai.NotFoundError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 404
self.message = "litellm.NotFoundError: {}".format(message)
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__(
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class BadRequestError(openai.BadRequestError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
body: Optional[dict] = None,
):
self.status_code = 400
self.message = "litellm.BadRequestError: {}".format(message)
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
        response = response or httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
self.max_retries = max_retries
self.num_retries = num_retries
super().__init__(
self.message, response=response, body=body
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 422
self.message = "litellm.UnprocessableEntityError: {}".format(message)
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class Timeout(openai.APITimeoutError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
headers: Optional[dict] = None,
):
request = httpx.Request(
method="POST",
url="https://api.openai.com/v1",
)
super().__init__(
request=request
) # Call the base class constructor with the parameters it needs
self.status_code = 408
self.message = "litellm.Timeout: {}".format(message)
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.headers = headers
# custom function to convert to str
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
def __init__(
self,
message,
llm_provider,
model,
response: httpx.Response,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 403
self.message = "litellm.PermissionDeniedError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
super().__init__(
self.message, response=response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class RateLimitError(openai.RateLimitError): # type: ignore
def __init__(
self,
message,
llm_provider,
model,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 429
self.message = "litellm.RateLimitError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
_response_headers = (
getattr(response, "headers", None) if response is not None else None
)
self.response = httpx.Response(
status_code=429,
headers=_response_headers,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
super().__init__(
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
self.code = "429"
self.type = "throttling_error"
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
# sub-class of bad request error - gives more granularity for handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
self.response = httpx.Response(status_code=400, request=request)
super().__init__(
message=message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=self.response,
litellm_debug_info=self.litellm_debug_info,
) # Call the base class constructor with the parameters it needs
# set after, to make it clear the raised error is a context window exceeded error
self.message = "litellm.ContextWindowExceededError: {}".format(self.message)
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
class RejectedRequestError(BadRequestError): # type: ignore
def __init__(
self,
message,
model,
llm_provider,
request_data: dict,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = "litellm.RejectedRequestError: {}".format(message)
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
self.request_data = request_data
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=400, request=request)
super().__init__(
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=response,
litellm_debug_info=self.litellm_debug_info,
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class ContentPolicyViolationError(BadRequestError): # type: ignore
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
def __init__(
self,
message,
model,
llm_provider,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
):
self.status_code = 400
self.message = "litellm.ContentPolicyViolationError: {}".format(message)
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
self.response = httpx.Response(status_code=400, request=request)
super().__init__(
message=self.message,
model=self.model, # type: ignore
llm_provider=self.llm_provider, # type: ignore
response=self.response,
litellm_debug_info=self.litellm_debug_info,
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class ServiceUnavailableError(openai.APIStatusError): # type: ignore
def __init__(
self,
message,
llm_provider,
model,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 503
self.message = "litellm.ServiceUnavailableError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
super().__init__(
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class InternalServerError(openai.InternalServerError): # type: ignore
def __init__(
self,
message,
llm_provider,
model,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 500
self.message = "litellm.InternalServerError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
self.response = httpx.Response(
status_code=self.status_code,
request=httpx.Request(
method="POST",
url=" https://cloud.google.com/vertex-ai/",
),
)
super().__init__(
self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(openai.APIError): # type: ignore
def __init__(
self,
status_code: int,
message,
llm_provider,
model,
request: Optional[httpx.Request] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = status_code
self.message = "litellm.APIError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
if request is None:
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__(self.message, request=request, body=None) # type: ignore
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
# raised when the request fails to reach the provider (network / connection errors)
class APIConnectionError(openai.APIConnectionError): # type: ignore
def __init__(
self,
message,
llm_provider,
model,
request: Optional[httpx.Request] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.message = "litellm.APIConnectionError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.status_code = 500
self.litellm_debug_info = litellm_debug_info
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
self.max_retries = max_retries
self.num_retries = num_retries
super().__init__(message=self.message, request=self.request)
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
# raised when the API response fails validation
class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
def __init__(
self,
message,
llm_provider,
model,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.message = "litellm.APIResponseValidationError: {}".format(message)
self.llm_provider = llm_provider
self.model = model
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
response = httpx.Response(status_code=500, request=request)
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
super().__init__(response=response, body=None, message=message)
def __str__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
def __repr__(self):
_message = self.message
if self.num_retries:
_message += f" LiteLLM Retried: {self.num_retries} times"
if self.max_retries:
_message += f", LiteLLM Max Retries: {self.max_retries}"
return _message
class JSONSchemaValidationError(APIResponseValidationError):
def __init__(
self, model: str, llm_provider: str, raw_response: str, schema: str
) -> None:
self.raw_response = raw_response
self.schema = schema
self.model = model
message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
model, raw_response, schema
)
self.message = message
super().__init__(model=model, message=message, llm_provider=llm_provider)
class OpenAIError(openai.OpenAIError): # type: ignore
def __init__(self, original_exception=None):
super().__init__()
self.llm_provider = "openai"
class UnsupportedParamsError(BadRequestError):
def __init__(
self,
message,
llm_provider: Optional[str] = None,
model: Optional[str] = None,
status_code: int = 400,
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = 400
self.message = "litellm.UnsupportedParamsError: {}".format(message)
self.model = model
self.llm_provider = llm_provider
self.litellm_debug_info = litellm_debug_info
        response = response or httpx.Response(
            status_code=self.status_code,
            request=httpx.Request(
                method="GET", url="https://litellm.ai"
            ),  # mock request object
        )
        super().__init__(
            message=message,
            model=model,  # type: ignore
            llm_provider=llm_provider,  # type: ignore
            response=response,
            litellm_debug_info=litellm_debug_info,
            max_retries=max_retries,
            num_retries=num_retries,
        )  # call the base class constructor so the openai error state is initialized
        # re-apply the prefix, since BadRequestError.__init__ overwrites self.message
        self.message = "litellm.UnsupportedParamsError: {}".format(message)
LITELLM_EXCEPTION_TYPES = [
AuthenticationError,
NotFoundError,
BadRequestError,
UnprocessableEntityError,
UnsupportedParamsError,
Timeout,
PermissionDeniedError,
RateLimitError,
ContextWindowExceededError,
RejectedRequestError,
ContentPolicyViolationError,
InternalServerError,
ServiceUnavailableError,
APIError,
APIConnectionError,
APIResponseValidationError,
OpenAIError,
InternalServerError,
JSONSchemaValidationError,
]
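# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition): handling the mapped exceptions above
# around a completion call. `litellm.completion` is real; the model name and
# the handling strategy are illustrative.
# ---------------------------------------------------------------------------
def _handle_exceptions_demo() -> None:
    import litellm

    try:
        litellm.completion(
            model="gpt-4o", messages=[{"role": "user", "content": "hi"}]
        )
    except RateLimitError as e:
        print("throttled (429):", e)
    except ContextWindowExceededError:
        print("prompt too long; trim messages and retry")
    except APIConnectionError as e:
        print("network-level failure:", e)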
class BudgetExceededError(Exception):
def __init__(
self, current_cost: float, max_budget: float, message: Optional[str] = None
):
self.current_cost = current_cost
self.max_budget = max_budget
message = (
message
or f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
)
self.message = message
super().__init__(message)
## DEPRECATED ##
class InvalidRequestError(openai.BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
self.response = httpx.Response(
status_code=400,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
super().__init__(
message=self.message, response=self.response, body=None
) # Call the base class constructor with the parameters it needs
class MockException(openai.APIError):
# used for testing
def __init__(
self,
status_code: int,
message,
llm_provider,
model,
request: Optional[httpx.Request] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
):
self.status_code = status_code
self.message = "litellm.MockException: {}".format(message)
self.llm_provider = llm_provider
self.model = model
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
if request is None:
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__(self.message, request=request, body=None) # type: ignore
class LiteLLMUnknownProvider(BadRequestError):
def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
model=model, custom_llm_provider=custom_llm_provider
)
super().__init__(
self.message, model=model, llm_provider=custom_llm_provider, response=None
)
def __str__(self):
return self.message

View File

@@ -0,0 +1,6 @@
# LiteLLM MCP Client
LiteLLM MCP Client lets you use MCP tools with LiteLLM.

View File

@@ -0,0 +1,3 @@
from .tools import call_openai_tool, load_mcp_tools
__all__ = ["load_mcp_tools", "call_openai_tool"]

View File

@@ -0,0 +1,111 @@
import json
from typing import Dict, List, Literal, Union
from mcp import ClientSession
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import Tool as MCPTool
from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params.function_definition import FunctionDefinition
from litellm.types.utils import ChatCompletionMessageToolCall
########################################################
# List MCP Tool functions
########################################################
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
"""Convert an MCP tool to an OpenAI tool."""
return ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=mcp_tool.name,
description=mcp_tool.description or "",
parameters=mcp_tool.inputSchema,
strict=False,
),
)
async def load_mcp_tools(
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
"""
Load all available MCP tools
Args:
session: The MCP session to use
format: The format to convert the tools to
By default, the tools are returned in MCP format.
If format is set to "openai", the tools are converted to OpenAI API compatible tools.
"""
tools = await session.list_tools()
if format == "openai":
return [
transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools
]
return tools.tools
########################################################
# Call MCP Tool functions
########################################################
async def call_mcp_tool(
session: ClientSession,
call_tool_request_params: MCPCallToolRequestParams,
) -> MCPCallToolResult:
"""Call an MCP tool."""
tool_result = await session.call_tool(
name=call_tool_request_params.name,
arguments=call_tool_request_params.arguments,
)
return tool_result
def _get_function_arguments(function: FunctionDefinition) -> dict:
"""Helper to safely get and parse function arguments."""
arguments = function.get("arguments", {})
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
return arguments if isinstance(arguments, dict) else {}
def transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
function = openai_tool["function"]
return MCPCallToolRequestParams(
name=function["name"],
arguments=_get_function_arguments(function),
)
async def call_openai_tool(
session: ClientSession,
openai_tool: ChatCompletionMessageToolCall,
) -> MCPCallToolResult:
"""
Call an OpenAI tool using MCP client.
Args:
session: The MCP session to use
openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
Returns:
The result of the MCP tool call.
"""
mcp_tool_call_request_params = (
transform_openai_tool_call_request_to_mcp_tool_call_request(
openai_tool=openai_tool,
)
)
return await call_mcp_tool(
session=session,
call_tool_request_params=mcp_tool_call_request_params,
)
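# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition): wiring MCP tools into a chat
# completion. `session` is assumed to be an initialized `mcp.ClientSession`;
# the `litellm.acompletion` call and the model name are illustrative, and the
# sketch assumes the model actually chose to call a tool.
# ---------------------------------------------------------------------------
async def _mcp_tools_demo(session: ClientSession) -> None:
    import litellm

    tools = await load_mcp_tools(session=session, format="openai")
    response = await litellm.acompletion(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Use a tool to answer this."}],
        tools=tools,
    )
    # Assumes the model returned at least one tool call
    tool_call = response.choices[0].message.tool_calls[0]
    result = await call_openai_tool(session=session, openai_tool=tool_call)
    print(result)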

View File

@@ -0,0 +1,886 @@
"""
Main File for Files API implementation
https://platform.openai.com/docs/api-reference/files
"""
import asyncio
import contextvars
import os
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
import httpx
import litellm
from litellm import get_secret_str
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
from litellm.types.llms.openai import (
CreateFileRequest,
FileContentRequest,
FileTypes,
HttpxBinaryResponseContent,
OpenAIFileObject,
)
from litellm.types.router import *
from litellm.types.utils import LlmProviders
from litellm.utils import (
ProviderConfigManager,
client,
get_litellm_params,
supports_httpx_timeout,
)
base_llm_http_handler = BaseLLMHTTPHandler()
####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()
azure_files_instance = AzureOpenAIFilesAPI()
vertex_ai_files_instance = VertexAIFilesHandler()
#################################################
@client
async def acreate_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> OpenAIFileObject:
"""
Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
    LiteLLM equivalent of POST https://api.openai.com/v1/files
"""
try:
loop = asyncio.get_event_loop()
kwargs["acreate_file"] = True
call_args = {
"file": file,
"purpose": purpose,
"custom_llm_provider": custom_llm_provider,
"extra_headers": extra_headers,
"extra_body": extra_body,
**kwargs,
}
# Use a partial function to pass your keyword arguments
func = partial(create_file, **call_args)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def create_file(
file: FileTypes,
purpose: Literal["assistants", "batch", "fine-tune"],
custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
"""
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
    LiteLLM equivalent of POST https://api.openai.com/v1/files
    Specify the provider via `custom_llm_provider` (one of "openai", "azure", "vertex_ai").
"""
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
logging_obj = cast(
Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
)
if logging_obj is None:
raise ValueError("logging_obj is required")
client = kwargs.get("client")
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_create_file_request = CreateFileRequest(
file=file,
purpose=purpose,
extra_headers=extra_headers,
extra_body=extra_body,
)
provider_config = ProviderConfigManager.get_provider_files_config(
model="",
provider=LlmProviders(custom_llm_provider),
)
if provider_config is not None:
response = base_llm_http_handler.create_file(
provider_config=provider_config,
litellm_params=litellm_params_dict,
create_file_data=_create_file_request,
headers=extra_headers or {},
api_base=optional_params.api_base,
api_key=optional_params.api_key,
logging_obj=logging_obj,
_is_async=_is_async,
client=client
if client is not None
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
else None,
timeout=timeout,
)
elif custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
create_file_data=_create_file_request,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
response = vertex_ai_files_instance.create_file(
_is_async=_is_async,
api_base=api_base,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai'] are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
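# Usage sketch (illustrative, not part of this module). Assumes a valid
# OPENAI_API_KEY is set; the file name "batch_requests.jsonl" is hypothetical:
#
#     import litellm
#
#     file_obj = litellm.create_file(
#         file=open("batch_requests.jsonl", "rb"),
#         purpose="batch",
#         custom_llm_provider="openai",
#     )
#     print(file_obj.id)  # e.g. "file-abc123"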
async def afile_retrieve(
file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
Async: Retrieve information about a file
LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}
"""
try:
loop = asyncio.get_event_loop()
kwargs["is_async"] = True
# Use a partial function to pass your keyword arguments
func = partial(
file_retrieve,
file_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise e
def file_retrieve(
file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FileObject:
"""
Returns information about the specified file.
LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_files_instance.retrieve_file(
file_id=file_id,
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.retrieve_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
file_id=file_id,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return cast(FileObject, response)
except Exception as e:
raise e
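# Usage sketch (illustrative). Assumes "file-abc123" is the id returned by an
# earlier create_file call:
#
#     file_info = litellm.file_retrieve(
#         file_id="file-abc123",
#         custom_llm_provider="openai",
#     )
#     print(file_info.filename, file_info.purpose)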
# Delete file
async def afile_delete(
file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Coroutine[Any, Any, FileDeleted]:
"""
Async: Delete file
LiteLLM Equivalent of DELETE https://api.openai.com/v1/files/{file_id}
"""
try:
loop = asyncio.get_event_loop()
kwargs["is_async"] = True
# Use a partial function to pass your keyword arguments
func = partial(
file_delete,
file_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return cast(FileDeleted, response) # type: ignore
except Exception as e:
raise e
def file_delete(
file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FileDeleted:
"""
Delete file
LiteLLM Equivalent of DELETE https://api.openai.com/v1/files/{file_id}
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
client = kwargs.get("client")
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_files_instance.delete_file(
file_id=file_id,
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.delete_file(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
file_id=file_id,
client=client,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return cast(FileDeleted, response)
except Exception as e:
raise e
# List files
async def afile_list(
custom_llm_provider: Literal["openai", "azure"] = "openai",
purpose: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
Async: List files
LiteLLM Equivalent of GET https://api.openai.com/v1/files
"""
try:
loop = asyncio.get_event_loop()
kwargs["is_async"] = True
# Use a partial function to pass your keyword arguments
func = partial(
file_list,
custom_llm_provider,
purpose,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def file_list(
custom_llm_provider: Literal["openai", "azure"] = "openai",
purpose: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
List files
LiteLLM Equivalent of GET https://api.openai.com/v1/files
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("is_async", False) is True
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_files_instance.list_files(
purpose=purpose,
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.list_files(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
purpose=purpose,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'file_list'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="file_list", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
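# Usage sketch (illustrative). Lists uploaded files, optionally filtered by purpose:
#
#     files = litellm.file_list(
#         custom_llm_provider="openai",
#         purpose="batch",  # optional; omit to list all files
#     )
#     for f in files.data:
#         print(f.id, f.filename)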
async def afile_content(
file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> HttpxBinaryResponseContent:
"""
Async: Get file contents
LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}/content
"""
try:
loop = asyncio.get_event_loop()
kwargs["afile_content"] = True
# Use a partial function to pass your keyword arguments
func = partial(
file_content,
file_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def file_content(
file_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]:
"""
Returns the contents of the specified file.
LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}/content
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
client = kwargs.get("client")
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_file_content_request = FileContentRequest(
file_id=file_id,
extra_headers=extra_headers,
extra_body=extra_body,
)
_is_async = kwargs.pop("afile_content", False) is True
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_files_instance.file_content(
_is_async=_is_async,
file_content_request=_file_content_request,
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.file_content(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
file_content_request=_file_content_request,
client=client,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
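# Usage sketch (illustrative). Downloads the raw bytes of an uploaded file,
# e.g. batch output; the file id is hypothetical:
#
#     content = litellm.file_content(
#         file_id="file-abc123",
#         custom_llm_provider="openai",
#     )
#     with open("batch_results.jsonl", "wb") as f:
#         f.write(content.content)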

View File

@@ -0,0 +1,757 @@
"""
Main File for Fine Tuning API implementation
https://platform.openai.com/docs/api-reference/fine-tuning
- fine_tuning.jobs.create()
- fine_tuning.jobs.list()
- fine_tuning.jobs.list_events()
"""
import asyncio
import contextvars
import os
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.azure.fine_tuning.handler import AzureOpenAIFineTuningAPI
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
from litellm.llms.vertex_ai.fine_tuning.handler import VertexFineTuningAPI
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
FineTuningJob,
FineTuningJobCreate,
Hyperparameters,
)
from litellm.types.router import *
from litellm.utils import client, supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
#################################################
@client
async def acreate_fine_tuning_job(
model: str,
training_file: str,
hyperparameters: Optional[dict] = None,
suffix: Optional[str] = None,
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FineTuningJob:
"""
Async: Creates a fine-tuning job from an uploaded training file
"""
verbose_logger.debug(
"inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
)
try:
loop = asyncio.get_event_loop()
kwargs["acreate_fine_tuning_job"] = True
# Use a partial function to pass your keyword arguments
func = partial(
create_fine_tuning_job,
model,
training_file,
hyperparameters,
suffix,
validation_file,
integrations,
seed,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
@client
def create_fine_tuning_job(
model: str,
training_file: str,
hyperparameters: Optional[dict] = None,
suffix: Optional[str] = None,
validation_file: Optional[str] = None,
integrations: Optional[List[str]] = None,
seed: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
"""
Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
"""
try:
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
# handle hyperparameters
hyperparameters = hyperparameters or {} # original hyperparameters
_oai_hyperparameters: Hyperparameters = Hyperparameters(
**hyperparameters
) # Typed Hyperparameters for OpenAI Spec
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
training_file=training_file,
hyperparameters=_oai_hyperparameters,
suffix=suffix,
validation_file=validation_file,
integrations=integrations,
seed=seed,
)
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
exclude_none=True
)
response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get(
"client", None
), # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
training_file=training_file,
hyperparameters=_oai_hyperparameters,
suffix=suffix,
validation_file=validation_file,
integrations=integrations,
seed=seed,
)
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
exclude_none=True
)
response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
training_file=training_file,
hyperparameters=_oai_hyperparameters,
suffix=suffix,
validation_file=validation_file,
integrations=integrations,
seed=seed,
)
response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
_is_async=_is_async,
create_fine_tuning_job_data=create_fine_tuning_job_data,
vertex_credentials=vertex_credentials,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
timeout=timeout,
api_base=api_base,
kwargs=kwargs,
original_hyperparameters=hyperparameters,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
raise e
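# Usage sketch (illustrative). Assumes "file-abc123" is an uploaded JSONL
# training file and a valid OPENAI_API_KEY is set:
#
#     ft_job = litellm.create_fine_tuning_job(
#         model="gpt-3.5-turbo",
#         training_file="file-abc123",
#         hyperparameters={"n_epochs": 3},
#         custom_llm_provider="openai",
#     )
#     print(ft_job.id, ft_job.status)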
async def acancel_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FineTuningJob:
"""
Async: Immediately cancel a fine-tune job.
"""
try:
loop = asyncio.get_event_loop()
kwargs["acancel_fine_tuning_job"] = True
# Use a partial function to pass your keyword arguments
func = partial(
cancel_fine_tuning_job,
fine_tuning_job_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def cancel_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
"""
Immediately cancel a fine-tune job.
Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
async def alist_fine_tuning_jobs(
after: Optional[str] = None,
limit: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
Async: List your organization's fine-tuning jobs
"""
try:
loop = asyncio.get_event_loop()
kwargs["alist_fine_tuning_jobs"] = True
# Use a partial function to pass your keyword arguments
func = partial(
list_fine_tuning_jobs,
after,
limit,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def list_fine_tuning_jobs(
after: Optional[str] = None,
limit: Optional[int] = None,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
):
"""
List your organization's fine-tuning jobs
Params:
- after: Optional[str] = None, Identifier for the last job from the previous pagination request.
- limit: Optional[int] = None, Number of fine-tuning jobs to retrieve. Defaults to 20
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
# OpenAI
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
after=after,
limit=limit,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,
api_key=api_key,
api_version=api_version,
after=after,
limit=limit,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
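# Usage sketch (illustrative). Pages through fine-tuning jobs, 10 at a time;
# pass the last job id as `after` to fetch the next page:
#
#     page = litellm.list_fine_tuning_jobs(
#         custom_llm_provider="openai",
#         limit=10,
#     )
#     for job in page.data:
#         print(job.id, job.status)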
async def aretrieve_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FineTuningJob:
"""
Async: Get info about a fine-tuning job.
"""
try:
loop = asyncio.get_event_loop()
kwargs["aretrieve_fine_tuning_job"] = True
# Use a partial function to pass your keyword arguments
func = partial(
retrieve_fine_tuning_job,
fine_tuning_job_id,
custom_llm_provider,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
return response
except Exception as e:
raise e
def retrieve_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
"""
Get info about a fine-tuning job.
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True
# OpenAI
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=optional_params.api_version,
organization=organization,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
client=kwargs.get("client", None),
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
api_base=api_base,
api_key=api_key,
api_version=api_version,
fine_tuning_job_id=fine_tuning_job_id,
timeout=timeout,
max_retries=optional_params.max_retries,
_is_async=_is_async,
organization=optional_params.organization,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e

View File

@@ -0,0 +1,5 @@
# Integrations
This folder contains logging integrations for litellm, e.g. logging to Datadog, Langfuse, Prometheus, s3, GCS Bucket, etc.

View File

@@ -0,0 +1,13 @@
# Slack Alerting on LiteLLM Gateway
This folder contains the Slack Alerting integration for LiteLLM Gateway.
## Folder Structure
- `slack_alerting.py`: This is the main file that handles sending different types of alerts
- `batching_handler.py`: Handles batching and sending httpx POST requests to Slack. Alerts are sent every 10s, or sooner once the queued events exceed a threshold, to keep litellm performant under high traffic.
- `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack.
- `utils.py`: This file contains common utils used specifically for slack alerting
## Further Reading
- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)

View File

@@ -0,0 +1,81 @@
"""
Handles batching and sending httpx POST requests to Slack.
Slack alerts are sent every 10s, or sooner once the queued events exceed a threshold.
See custom_batch_logger.py for details and defaults.
"""
from typing import TYPE_CHECKING, Any
from litellm._logging import verbose_proxy_logger
if TYPE_CHECKING:
from .slack_alerting import SlackAlerting as _SlackAlerting
SlackAlertingType = _SlackAlerting
else:
SlackAlertingType = Any
def squash_payloads(queue):
squashed = {}
if len(queue) == 0:
return squashed
if len(queue) == 1:
# single item: the literal "key" is a placeholder, since callers only iterate over the values
return {"key": {"item": queue[0], "count": 1}}
for item in queue:
url = item["url"]
alert_type = item["alert_type"]
_key = (url, alert_type)
if _key in squashed:
# duplicate (url, alert_type): bump the count; the first payload for this key is kept
squashed[_key]["count"] += 1
else:
squashed[_key] = {"item": item, "count": 1}
return squashed
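# Example (illustrative): two alerts sharing the same (url, alert_type) collapse
# into one entry with count == 2; the payload contents here are hypothetical:
#
#     queue = [
#         {"url": "https://hooks.slack.com/T1", "alert_type": "llm_exceptions", "payload": {"text": "a"}},
#         {"url": "https://hooks.slack.com/T1", "alert_type": "llm_exceptions", "payload": {"text": "b"}},
#     ]
#     squash_payloads(queue)
#     # -> {("https://hooks.slack.com/T1", "llm_exceptions"): {"item": <first alert>, "count": 2}}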
def _print_alerting_payload_warning(
payload: dict, slackAlertingInstance: SlackAlertingType
):
"""
Print the payload to the console when
slackAlertingInstance.alerting_args.log_to_console is True
Relevant issue: https://github.com/BerriAI/litellm/issues/7372
"""
if slackAlertingInstance.alerting_args.log_to_console is True:
verbose_proxy_logger.warning(payload)
async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count):
"""
Send a single slack alert to the webhook
"""
import json
payload = item.get("payload", {})
try:
if count > 1:
payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"
response = await slackAlertingInstance.async_http_handler.post(
url=item["url"],
headers=item["headers"],
data=json.dumps(payload),
)
if response.status_code != 200:
verbose_proxy_logger.debug(
f"Error sending slack alert to url={item['url']}. Error={response.text}"
)
except Exception as e:
verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
finally:
_print_alerting_payload_warning(
payload, slackAlertingInstance=slackAlertingInstance
)

View File

@@ -0,0 +1,92 @@
"""
Utils used for slack alerting
"""
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from litellm.proxy._types import AlertType
from litellm.secret_managers.main import get_secret
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _Logging
Logging = _Logging
else:
Logging = Any
def process_slack_alerting_variables(
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
"""
Process alert_to_webhook_url:
- if any url is set as "os.environ/SLACK_WEBHOOK_URL_1", read the env var and substitute its resolved value
"""
if alert_to_webhook_url is None:
return None
for alert_type, webhook_urls in alert_to_webhook_url.items():
if isinstance(webhook_urls, list):
_webhook_values: List[str] = []
for webhook_url in webhook_urls:
if "os.environ/" in webhook_url:
_env_value = get_secret(secret_name=webhook_url)
if not isinstance(_env_value, str):
raise ValueError(
f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
)
_webhook_values.append(_env_value)
else:
_webhook_values.append(webhook_url)
alert_to_webhook_url[alert_type] = _webhook_values
else:
_webhook_value_str: str = webhook_urls
if "os.environ/" in webhook_urls:
_env_value = get_secret(secret_name=webhook_urls)
if not isinstance(_env_value, str):
raise ValueError(
f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}"
)
_webhook_value_str = _env_value
alert_to_webhook_url[alert_type] = _webhook_value_str
return alert_to_webhook_url
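# Example (illustrative). Assumes SLACK_WEBHOOK_URL_1 is set in the environment;
# values prefixed with "os.environ/" are replaced by the resolved secret string:
#
#     resolved = process_slack_alerting_variables(
#         {AlertType.llm_exceptions: "os.environ/SLACK_WEBHOOK_URL_1"}
#     )
#     # -> {AlertType.llm_exceptions: "https://hooks.slack.com/services/..."}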
async def _add_langfuse_trace_id_to_alert(
request_data: Optional[dict] = None,
) -> Optional[str]:
"""
Returns langfuse trace url
- check:
-> existing_trace_id
-> trace_id
-> litellm_call_id
"""
# look up the langfuse trace id from the logging object, retrying while the trace is created
if (
request_data is not None
and request_data.get("litellm_logging_obj", None) is not None
):
trace_id: Optional[str] = None
litellm_logging_obj: Logging = request_data["litellm_logging_obj"]
for _ in range(3):
trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
if trace_id is not None:
break
await asyncio.sleep(3) # wait 3s before retrying for trace id
_langfuse_object = litellm_logging_obj._get_callback_object(
service_name="langfuse"
)
if _langfuse_object is not None:
base_url = _langfuse_object.Langfuse.base_url
return f"{base_url}/trace/{trace_id}"
return None

View File

@@ -0,0 +1 @@
from . import *

Some files were not shown because too many files have changed in this diff.