structure saas with tools
This commit is contained in:
1082
.venv/lib/python3.10/site-packages/litellm/__init__.py
Normal file
1082
.venv/lib/python3.10/site-packages/litellm/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
167
.venv/lib/python3.10/site-packages/litellm/_logging.py
Normal file
167
.venv/lib/python3.10/site-packages/litellm/_logging.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from logging import Formatter
|
||||
|
||||
set_verbose = False
|
||||
|
||||
if set_verbose is True:
|
||||
logging.warning(
|
||||
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
|
||||
)
|
||||
json_logs = bool(os.getenv("JSON_LOGS", False))
|
||||
# Create a handler for the logger (you may need to adapt this based on your needs)
|
||||
log_level = os.getenv("LITELLM_LOG", "DEBUG")
|
||||
numeric_level: str = getattr(logging, log_level.upper())
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(numeric_level)
|
||||
|
||||
|
||||
class JsonFormatter(Formatter):
|
||||
def __init__(self):
|
||||
super(JsonFormatter, self).__init__()
|
||||
|
||||
def formatTime(self, record, datefmt=None):
|
||||
# Use datetime to format the timestamp in ISO 8601 format
|
||||
dt = datetime.fromtimestamp(record.created)
|
||||
return dt.isoformat()
|
||||
|
||||
def format(self, record):
|
||||
json_record = {
|
||||
"message": record.getMessage(),
|
||||
"level": record.levelname,
|
||||
"timestamp": self.formatTime(record),
|
||||
}
|
||||
|
||||
if record.exc_info:
|
||||
json_record["stacktrace"] = self.formatException(record.exc_info)
|
||||
|
||||
return json.dumps(json_record)
|
||||
|
||||
|
||||
# Function to set up exception handlers for JSON logging
|
||||
def _setup_json_exception_handlers(formatter):
|
||||
# Create a handler with JSON formatting for exceptions
|
||||
error_handler = logging.StreamHandler()
|
||||
error_handler.setFormatter(formatter)
|
||||
|
||||
# Setup excepthook for uncaught exceptions
|
||||
def json_excepthook(exc_type, exc_value, exc_traceback):
|
||||
record = logging.LogRecord(
|
||||
name="LiteLLM",
|
||||
level=logging.ERROR,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg=str(exc_value),
|
||||
args=(),
|
||||
exc_info=(exc_type, exc_value, exc_traceback),
|
||||
)
|
||||
error_handler.handle(record)
|
||||
|
||||
sys.excepthook = json_excepthook
|
||||
|
||||
# Configure asyncio exception handler if possible
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
def async_json_exception_handler(loop, context):
|
||||
exception = context.get("exception")
|
||||
if exception:
|
||||
record = logging.LogRecord(
|
||||
name="LiteLLM",
|
||||
level=logging.ERROR,
|
||||
pathname="",
|
||||
lineno=0,
|
||||
msg=str(exception),
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
error_handler.handle(record)
|
||||
else:
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Create a formatter and set it for the handler
|
||||
if json_logs:
|
||||
handler.setFormatter(JsonFormatter())
|
||||
_setup_json_exception_handlers(JsonFormatter())
|
||||
else:
|
||||
formatter = logging.Formatter(
|
||||
"\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
|
||||
verbose_router_logger = logging.getLogger("LiteLLM Router")
|
||||
verbose_logger = logging.getLogger("LiteLLM")
|
||||
|
||||
# Add the handler to the logger
|
||||
verbose_router_logger.addHandler(handler)
|
||||
verbose_proxy_logger.addHandler(handler)
|
||||
verbose_logger.addHandler(handler)
|
||||
|
||||
|
||||
def _turn_on_json():
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JsonFormatter())
|
||||
|
||||
# Define all loggers to update, including root logger
|
||||
loggers = [logging.getLogger()] + [
|
||||
verbose_router_logger,
|
||||
verbose_proxy_logger,
|
||||
verbose_logger,
|
||||
]
|
||||
|
||||
# Iterate through each logger and update its handlers
|
||||
for logger in loggers:
|
||||
# Remove all existing handlers
|
||||
for h in logger.handlers[:]:
|
||||
logger.removeHandler(h)
|
||||
# Add the new handler
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Set up exception handlers
|
||||
_setup_json_exception_handlers(JsonFormatter())
|
||||
|
||||
|
||||
def _turn_on_debug():
|
||||
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
|
||||
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
|
||||
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
|
||||
|
||||
|
||||
def _disable_debugging():
|
||||
verbose_logger.disabled = True
|
||||
verbose_router_logger.disabled = True
|
||||
verbose_proxy_logger.disabled = True
|
||||
|
||||
|
||||
def _enable_debugging():
|
||||
verbose_logger.disabled = False
|
||||
verbose_router_logger.disabled = False
|
||||
verbose_proxy_logger.disabled = False
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
if set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _is_debugging_on() -> bool:
|
||||
"""
|
||||
Returns True if debugging is on
|
||||
"""
|
||||
if verbose_logger.isEnabledFor(logging.DEBUG) or set_verbose is True:
|
||||
return True
|
||||
return False
|
||||
333
.venv/lib/python3.10/site-packages/litellm/_redis.py
Normal file
333
.venv/lib/python3.10/site-packages/litellm/_redis.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import redis # type: ignore
|
||||
import redis.asyncio as async_redis # type: ignore
|
||||
|
||||
from litellm import get_secret, get_secret_str
|
||||
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
|
||||
|
||||
from ._logging import verbose_logger
|
||||
|
||||
|
||||
def _get_redis_kwargs():
|
||||
arg_spec = inspect.getfullargspec(redis.Redis)
|
||||
|
||||
# Only allow primitive arguments
|
||||
exclude_args = {
|
||||
"self",
|
||||
"connection_pool",
|
||||
"retry",
|
||||
}
|
||||
|
||||
include_args = ["url"]
|
||||
|
||||
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
|
||||
|
||||
return available_args
|
||||
|
||||
|
||||
def _get_redis_url_kwargs(client=None):
|
||||
if client is None:
|
||||
client = redis.Redis.from_url
|
||||
arg_spec = inspect.getfullargspec(redis.Redis.from_url)
|
||||
|
||||
# Only allow primitive arguments
|
||||
exclude_args = {
|
||||
"self",
|
||||
"connection_pool",
|
||||
"retry",
|
||||
}
|
||||
|
||||
include_args = ["url"]
|
||||
|
||||
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
|
||||
|
||||
return available_args
|
||||
|
||||
|
||||
def _get_redis_cluster_kwargs(client=None):
|
||||
if client is None:
|
||||
client = redis.Redis.from_url
|
||||
arg_spec = inspect.getfullargspec(redis.RedisCluster)
|
||||
|
||||
# Only allow primitive arguments
|
||||
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
|
||||
|
||||
available_args = [x for x in arg_spec.args if x not in exclude_args]
|
||||
available_args.append("password")
|
||||
available_args.append("username")
|
||||
available_args.append("ssl")
|
||||
|
||||
return available_args
|
||||
|
||||
|
||||
def _get_redis_env_kwarg_mapping():
|
||||
PREFIX = "REDIS_"
|
||||
|
||||
return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
|
||||
|
||||
|
||||
def _redis_kwargs_from_environment():
|
||||
mapping = _get_redis_env_kwarg_mapping()
|
||||
|
||||
return_dict = {}
|
||||
for k, v in mapping.items():
|
||||
value = get_secret(k, default_value=None) # type: ignore
|
||||
if value is not None:
|
||||
return_dict[v] = value
|
||||
return return_dict
|
||||
|
||||
|
||||
def get_redis_url_from_environment():
|
||||
if "REDIS_URL" in os.environ:
|
||||
return os.environ["REDIS_URL"]
|
||||
|
||||
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
|
||||
raise ValueError(
|
||||
"Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
|
||||
)
|
||||
|
||||
if "REDIS_PASSWORD" in os.environ:
|
||||
redis_password = f":{os.environ['REDIS_PASSWORD']}@"
|
||||
else:
|
||||
redis_password = ""
|
||||
|
||||
return (
|
||||
f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
|
||||
)
|
||||
|
||||
|
||||
def _get_redis_client_logic(**env_overrides):
|
||||
"""
|
||||
Common functionality across sync + async redis client implementations
|
||||
"""
|
||||
### check if "os.environ/<key-name>" passed in
|
||||
for k, v in env_overrides.items():
|
||||
if isinstance(v, str) and v.startswith("os.environ/"):
|
||||
v = v.replace("os.environ/", "")
|
||||
value = get_secret(v) # type: ignore
|
||||
env_overrides[k] = value
|
||||
|
||||
redis_kwargs = {
|
||||
**_redis_kwargs_from_environment(),
|
||||
**env_overrides,
|
||||
}
|
||||
|
||||
_startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
|
||||
"REDIS_CLUSTER_NODES"
|
||||
)
|
||||
|
||||
if _startup_nodes is not None and isinstance(_startup_nodes, str):
|
||||
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
|
||||
|
||||
_sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
|
||||
"REDIS_SENTINEL_NODES"
|
||||
)
|
||||
|
||||
if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
|
||||
redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
|
||||
|
||||
_sentinel_password: Optional[str] = redis_kwargs.get(
|
||||
"sentinel_password", None
|
||||
) or get_secret_str("REDIS_SENTINEL_PASSWORD")
|
||||
|
||||
if _sentinel_password is not None:
|
||||
redis_kwargs["sentinel_password"] = _sentinel_password
|
||||
|
||||
_service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
|
||||
"REDIS_SERVICE_NAME"
|
||||
)
|
||||
|
||||
if _service_name is not None:
|
||||
redis_kwargs["service_name"] = _service_name
|
||||
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
redis_kwargs.pop("host", None)
|
||||
redis_kwargs.pop("port", None)
|
||||
redis_kwargs.pop("db", None)
|
||||
redis_kwargs.pop("password", None)
|
||||
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
|
||||
pass
|
||||
elif (
|
||||
"sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
|
||||
):
|
||||
pass
|
||||
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
|
||||
raise ValueError("Either 'host' or 'url' must be specified for redis.")
|
||||
|
||||
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
||||
return redis_kwargs
|
||||
|
||||
|
||||
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
||||
_redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
|
||||
if _redis_cluster_nodes_in_env is not None:
|
||||
try:
|
||||
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(
|
||||
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
|
||||
)
|
||||
|
||||
verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
|
||||
from redis.cluster import ClusterNode
|
||||
|
||||
args = _get_redis_cluster_kwargs()
|
||||
cluster_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
cluster_kwargs[arg] = redis_kwargs[arg]
|
||||
|
||||
new_startup_nodes: List[ClusterNode] = []
|
||||
|
||||
for item in redis_kwargs["startup_nodes"]:
|
||||
new_startup_nodes.append(ClusterNode(**item))
|
||||
|
||||
redis_kwargs.pop("startup_nodes")
|
||||
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
|
||||
|
||||
|
||||
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
|
||||
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
|
||||
sentinel_password = redis_kwargs.get("sentinel_password")
|
||||
service_name = redis_kwargs.get("service_name")
|
||||
|
||||
if not sentinel_nodes or not service_name:
|
||||
raise ValueError(
|
||||
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
|
||||
)
|
||||
|
||||
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
|
||||
|
||||
# Set up the Sentinel client
|
||||
sentinel = redis.Sentinel(
|
||||
sentinel_nodes,
|
||||
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
||||
password=sentinel_password,
|
||||
)
|
||||
|
||||
# Return the master instance for the given service
|
||||
|
||||
return sentinel.master_for(service_name)
|
||||
|
||||
|
||||
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
|
||||
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
|
||||
sentinel_password = redis_kwargs.get("sentinel_password")
|
||||
service_name = redis_kwargs.get("service_name")
|
||||
|
||||
if not sentinel_nodes or not service_name:
|
||||
raise ValueError(
|
||||
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
|
||||
)
|
||||
|
||||
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
|
||||
|
||||
# Set up the Sentinel client
|
||||
sentinel = async_redis.Sentinel(
|
||||
sentinel_nodes,
|
||||
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
||||
password=sentinel_password,
|
||||
)
|
||||
|
||||
# Return the master instance for the given service
|
||||
|
||||
return sentinel.master_for(service_name)
|
||||
|
||||
|
||||
def get_redis_client(**env_overrides):
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
args = _get_redis_url_kwargs()
|
||||
url_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
url_kwargs[arg] = redis_kwargs[arg]
|
||||
|
||||
return redis.Redis.from_url(**url_kwargs)
|
||||
|
||||
if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
|
||||
return init_redis_cluster(redis_kwargs)
|
||||
|
||||
# Check for Redis Sentinel
|
||||
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
|
||||
return _init_redis_sentinel(redis_kwargs)
|
||||
|
||||
return redis.Redis(**redis_kwargs)
|
||||
|
||||
|
||||
def get_redis_async_client(
|
||||
**env_overrides,
|
||||
) -> async_redis.Redis:
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
|
||||
url_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
url_kwargs[arg] = redis_kwargs[arg]
|
||||
else:
|
||||
verbose_logger.debug(
|
||||
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
|
||||
arg
|
||||
)
|
||||
)
|
||||
return async_redis.Redis.from_url(**url_kwargs)
|
||||
|
||||
if "startup_nodes" in redis_kwargs:
|
||||
from redis.cluster import ClusterNode
|
||||
|
||||
args = _get_redis_cluster_kwargs()
|
||||
cluster_kwargs = {}
|
||||
for arg in redis_kwargs:
|
||||
if arg in args:
|
||||
cluster_kwargs[arg] = redis_kwargs[arg]
|
||||
|
||||
new_startup_nodes: List[ClusterNode] = []
|
||||
|
||||
for item in redis_kwargs["startup_nodes"]:
|
||||
new_startup_nodes.append(ClusterNode(**item))
|
||||
redis_kwargs.pop("startup_nodes")
|
||||
return async_redis.RedisCluster(
|
||||
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
|
||||
)
|
||||
|
||||
# Check for Redis Sentinel
|
||||
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
|
||||
return _init_async_redis_sentinel(redis_kwargs)
|
||||
|
||||
return async_redis.Redis(
|
||||
**redis_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_redis_connection_pool(**env_overrides):
|
||||
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
||||
verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs)
|
||||
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
||||
return async_redis.BlockingConnectionPool.from_url(
|
||||
timeout=REDIS_CONNECTION_POOL_TIMEOUT, url=redis_kwargs["url"]
|
||||
)
|
||||
connection_class = async_redis.Connection
|
||||
if "ssl" in redis_kwargs:
|
||||
connection_class = async_redis.SSLConnection
|
||||
redis_kwargs.pop("ssl", None)
|
||||
redis_kwargs["connection_class"] = connection_class
|
||||
redis_kwargs.pop("startup_nodes", None)
|
||||
return async_redis.BlockingConnectionPool(
|
||||
timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
|
||||
)
|
||||
311
.venv/lib/python3.10/site-packages/litellm/_service_logger.py
Normal file
311
.venv/lib/python3.10/site-packages/litellm/_service_logger.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
from .integrations.custom_logger import CustomLogger
|
||||
from .integrations.datadog.datadog import DataDogLogger
|
||||
from .integrations.opentelemetry import OpenTelemetry
|
||||
from .integrations.prometheus_services import PrometheusServicesLogger
|
||||
from .types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
OTELClass = OpenTelemetry
|
||||
else:
|
||||
Span = Any
|
||||
OTELClass = Any
|
||||
|
||||
|
||||
class ServiceLogging(CustomLogger):
|
||||
"""
|
||||
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
|
||||
"""
|
||||
|
||||
def __init__(self, mock_testing: bool = False) -> None:
|
||||
self.mock_testing = mock_testing
|
||||
self.mock_testing_sync_success_hook = 0
|
||||
self.mock_testing_async_success_hook = 0
|
||||
self.mock_testing_sync_failure_hook = 0
|
||||
self.mock_testing_async_failure_hook = 0
|
||||
if "prometheus_system" in litellm.service_callback:
|
||||
self.prometheusServicesLogger = PrometheusServicesLogger()
|
||||
|
||||
def service_success_hook(
|
||||
self,
|
||||
service: ServiceTypes,
|
||||
duration: float,
|
||||
call_type: str,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[float, datetime]] = None,
|
||||
):
|
||||
"""
|
||||
Handles both sync and async monitoring by checking for existing event loop.
|
||||
"""
|
||||
|
||||
if self.mock_testing:
|
||||
self.mock_testing_sync_success_hook += 1
|
||||
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
# Check if the loop is running
|
||||
if loop.is_running():
|
||||
# If we're in a running loop, create a task
|
||||
loop.create_task(
|
||||
self.async_service_success_hook(
|
||||
service=service,
|
||||
duration=duration,
|
||||
call_type=call_type,
|
||||
parent_otel_span=parent_otel_span,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Loop exists but not running, we can use run_until_complete
|
||||
loop.run_until_complete(
|
||||
self.async_service_success_hook(
|
||||
service=service,
|
||||
duration=duration,
|
||||
call_type=call_type,
|
||||
parent_otel_span=parent_otel_span,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# No event loop exists, create a new one and run
|
||||
asyncio.run(
|
||||
self.async_service_success_hook(
|
||||
service=service,
|
||||
duration=duration,
|
||||
call_type=call_type,
|
||||
parent_otel_span=parent_otel_span,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
)
|
||||
|
||||
def service_failure_hook(
|
||||
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
|
||||
):
|
||||
"""
|
||||
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_sync_failure_hook += 1
|
||||
|
||||
async def async_service_success_hook(
|
||||
self,
|
||||
service: ServiceTypes,
|
||||
call_type: str,
|
||||
duration: float,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[datetime, float]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
- For counting if the redis, postgres call is successful
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_async_success_hook += 1
|
||||
|
||||
payload = ServiceLoggerPayload(
|
||||
is_error=False,
|
||||
error=None,
|
||||
service=service,
|
||||
duration=duration,
|
||||
call_type=call_type,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
|
||||
for callback in litellm.service_callback:
|
||||
if callback == "prometheus_system":
|
||||
await self.init_prometheus_services_logger_if_none()
|
||||
await self.prometheusServicesLogger.async_service_success_hook(
|
||||
payload=payload
|
||||
)
|
||||
elif callback == "datadog" or isinstance(callback, DataDogLogger):
|
||||
await self.init_datadog_logger_if_none()
|
||||
await self.dd_logger.async_service_success_hook(
|
||||
payload=payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
elif callback == "otel" or isinstance(callback, OpenTelemetry):
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
await self.init_otel_logger_if_none()
|
||||
|
||||
if (
|
||||
parent_otel_span is not None
|
||||
and open_telemetry_logger is not None
|
||||
and isinstance(open_telemetry_logger, OpenTelemetry)
|
||||
):
|
||||
await self.otel_logger.async_service_success_hook(
|
||||
payload=payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
|
||||
async def init_prometheus_services_logger_if_none(self):
|
||||
"""
|
||||
initializes prometheusServicesLogger if it is None or no attribute exists on ServiceLogging Object
|
||||
|
||||
"""
|
||||
if not hasattr(self, "prometheusServicesLogger"):
|
||||
self.prometheusServicesLogger = PrometheusServicesLogger()
|
||||
elif self.prometheusServicesLogger is None:
|
||||
self.prometheusServicesLogger = self.prometheusServicesLogger()
|
||||
return
|
||||
|
||||
async def init_datadog_logger_if_none(self):
|
||||
"""
|
||||
initializes dd_logger if it is None or no attribute exists on ServiceLogging Object
|
||||
|
||||
"""
|
||||
from litellm.integrations.datadog.datadog import DataDogLogger
|
||||
|
||||
if not hasattr(self, "dd_logger"):
|
||||
self.dd_logger: DataDogLogger = DataDogLogger()
|
||||
|
||||
return
|
||||
|
||||
async def init_otel_logger_if_none(self):
|
||||
"""
|
||||
initializes otel_logger if it is None or no attribute exists on ServiceLogging Object
|
||||
|
||||
"""
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
if not hasattr(self, "otel_logger"):
|
||||
if open_telemetry_logger is not None and isinstance(
|
||||
open_telemetry_logger, OpenTelemetry
|
||||
):
|
||||
self.otel_logger: OpenTelemetry = open_telemetry_logger
|
||||
else:
|
||||
verbose_logger.warning(
|
||||
"ServiceLogger: open_telemetry_logger is None or not an instance of OpenTelemetry"
|
||||
)
|
||||
return
|
||||
|
||||
async def async_service_failure_hook(
|
||||
self,
|
||||
service: ServiceTypes,
|
||||
duration: float,
|
||||
error: Union[str, Exception],
|
||||
call_type: str,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
start_time: Optional[Union[datetime, float]] = None,
|
||||
end_time: Optional[Union[float, datetime]] = None,
|
||||
event_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
- For counting if the redis, postgres call is unsuccessful
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_async_failure_hook += 1
|
||||
|
||||
error_message = ""
|
||||
if isinstance(error, Exception):
|
||||
error_message = str(error)
|
||||
elif isinstance(error, str):
|
||||
error_message = error
|
||||
|
||||
payload = ServiceLoggerPayload(
|
||||
is_error=True,
|
||||
error=error_message,
|
||||
service=service,
|
||||
duration=duration,
|
||||
call_type=call_type,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
|
||||
for callback in litellm.service_callback:
|
||||
if callback == "prometheus_system":
|
||||
await self.init_prometheus_services_logger_if_none()
|
||||
await self.prometheusServicesLogger.async_service_failure_hook(
|
||||
payload=payload,
|
||||
error=error,
|
||||
)
|
||||
elif callback == "datadog" or isinstance(callback, DataDogLogger):
|
||||
await self.init_datadog_logger_if_none()
|
||||
await self.dd_logger.async_service_failure_hook(
|
||||
payload=payload,
|
||||
error=error_message,
|
||||
parent_otel_span=parent_otel_span,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
elif callback == "otel" or isinstance(callback, OpenTelemetry):
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
await self.init_otel_logger_if_none()
|
||||
|
||||
if not isinstance(error, str):
|
||||
error = str(error)
|
||||
|
||||
if (
|
||||
parent_otel_span is not None
|
||||
and open_telemetry_logger is not None
|
||||
and isinstance(open_telemetry_logger, OpenTelemetry)
|
||||
):
|
||||
await self.otel_logger.async_service_success_hook(
|
||||
payload=payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
|
||||
async def async_post_call_failure_hook(
|
||||
self,
|
||||
request_data: dict,
|
||||
original_exception: Exception,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
):
|
||||
"""
|
||||
Hook to track failed litellm-service calls
|
||||
"""
|
||||
return await super().async_post_call_failure_hook(
|
||||
request_data,
|
||||
original_exception,
|
||||
user_api_key_dict,
|
||||
)
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Hook to track latency for litellm proxy llm api calls
|
||||
"""
|
||||
try:
|
||||
_duration = end_time - start_time
|
||||
if isinstance(_duration, timedelta):
|
||||
_duration = _duration.total_seconds()
|
||||
elif isinstance(_duration, float):
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
"Duration={} is not a float or timedelta object. type={}".format(
|
||||
_duration, type(_duration)
|
||||
)
|
||||
) # invalid _duration value
|
||||
await self.async_service_success_hook(
|
||||
service=ServiceTypes.LITELLM,
|
||||
duration=_duration,
|
||||
call_type=kwargs["call_type"],
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
6
.venv/lib/python3.10/site-packages/litellm/_version.py
Normal file
6
.venv/lib/python3.10/site-packages/litellm/_version.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import importlib_metadata
|
||||
|
||||
try:
|
||||
version = importlib_metadata.version("litellm")
|
||||
except Exception:
|
||||
version = "unknown"
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Anthropic module for LiteLLM
|
||||
"""
|
||||
from .messages import acreate, create
|
||||
|
||||
__all__ = ["acreate", "create"]
|
||||
Binary file not shown.
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Interface for Anthropic's messages API
|
||||
|
||||
Use this to call LLMs in Anthropic /messages Request/Response format
|
||||
|
||||
This is an __init__.py file to allow the following interface
|
||||
|
||||
- litellm.messages.acreate
|
||||
- litellm.messages.create
|
||||
|
||||
"""
|
||||
|
||||
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
|
||||
anthropic_messages as _async_anthropic_messages,
|
||||
)
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
|
||||
|
||||
async def acreate(
|
||||
max_tokens: int,
|
||||
messages: List[Dict],
|
||||
model: str,
|
||||
metadata: Optional[Dict] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
system: Optional[str] = None,
|
||||
temperature: Optional[float] = 1.0,
|
||||
thinking: Optional[Dict] = None,
|
||||
tool_choice: Optional[Dict] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
|
||||
"""
|
||||
Async wrapper for Anthropic's messages API
|
||||
|
||||
Args:
|
||||
max_tokens (int): Maximum tokens to generate (required)
|
||||
messages (List[Dict]): List of message objects with role and content (required)
|
||||
model (str): Model name to use (required)
|
||||
metadata (Dict, optional): Request metadata
|
||||
stop_sequences (List[str], optional): Custom stop sequences
|
||||
stream (bool, optional): Whether to stream the response
|
||||
system (str, optional): System prompt
|
||||
temperature (float, optional): Sampling temperature (0.0 to 1.0)
|
||||
thinking (Dict, optional): Extended thinking configuration
|
||||
tool_choice (Dict, optional): Tool choice configuration
|
||||
tools (List[Dict], optional): List of tool definitions
|
||||
top_k (int, optional): Top K sampling parameter
|
||||
top_p (float, optional): Nucleus sampling parameter
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
Dict: Response from the API
|
||||
"""
|
||||
return await _async_anthropic_messages(
|
||||
max_tokens=max_tokens,
|
||||
messages=messages,
|
||||
model=model,
|
||||
metadata=metadata,
|
||||
stop_sequences=stop_sequences,
|
||||
stream=stream,
|
||||
system=system,
|
||||
temperature=temperature,
|
||||
thinking=thinking,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def create(
|
||||
max_tokens: int,
|
||||
messages: List[Dict],
|
||||
model: str,
|
||||
metadata: Optional[Dict] = None,
|
||||
stop_sequences: Optional[List[str]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
system: Optional[str] = None,
|
||||
temperature: Optional[float] = 1.0,
|
||||
thinking: Optional[Dict] = None,
|
||||
tool_choice: Optional[Dict] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> Union[AnthropicMessagesResponse, Iterator]:
|
||||
"""
|
||||
Async wrapper for Anthropic's messages API
|
||||
|
||||
Args:
|
||||
max_tokens (int): Maximum tokens to generate (required)
|
||||
messages (List[Dict]): List of message objects with role and content (required)
|
||||
model (str): Model name to use (required)
|
||||
metadata (Dict, optional): Request metadata
|
||||
stop_sequences (List[str], optional): Custom stop sequences
|
||||
stream (bool, optional): Whether to stream the response
|
||||
system (str, optional): System prompt
|
||||
temperature (float, optional): Sampling temperature (0.0 to 1.0)
|
||||
thinking (Dict, optional): Extended thinking configuration
|
||||
tool_choice (Dict, optional): Tool choice configuration
|
||||
tools (List[Dict], optional): List of tool definitions
|
||||
top_k (int, optional): Top K sampling parameter
|
||||
top_p (float, optional): Nucleus sampling parameter
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
Dict: Response from the API
|
||||
"""
|
||||
raise NotImplementedError("This function is not implemented")
|
||||
Binary file not shown.
@@ -0,0 +1,116 @@
|
||||
## Use LLM API endpoints in Anthropic Interface
|
||||
|
||||
Note: This is called `anthropic_interface` because `anthropic` is a known python package and was failing mypy type checking.
|
||||
|
||||
|
||||
## Usage
|
||||
---
|
||||
|
||||
### LiteLLM Python SDK
|
||||
|
||||
#### Non-streaming example
|
||||
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
||||
import litellm
|
||||
response = await litellm.anthropic.messages.acreate(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
api_key=api_key,
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
)
|
||||
```
|
||||
|
||||
Example response:
|
||||
```json
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"text": "Hi! this is a very short joke",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"role": "assistant",
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": null,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": 2095,
|
||||
"output_tokens": 503,
|
||||
"cache_creation_input_tokens": 2095,
|
||||
"cache_read_input_tokens": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Streaming example
|
||||
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
||||
import litellm
|
||||
response = await litellm.anthropic.messages.acreate(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
api_key=api_key,
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
### LiteLLM Proxy Server
|
||||
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: anthropic-claude
|
||||
litellm_params:
|
||||
model: claude-3-7-sonnet-latest
|
||||
```
|
||||
|
||||
2. Start proxy
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="Anthropic Python SDK" value="python">
|
||||
|
||||
```python showLineNumbers title="Example using LiteLLM Proxy Server"
|
||||
import anthropic
|
||||
|
||||
# point anthropic sdk to litellm proxy
|
||||
client = anthropic.Anthropic(
|
||||
base_url="http://0.0.0.0:4000",
|
||||
api_key="sk-1234",
|
||||
)
|
||||
|
||||
response = client.messages.create(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem label="curl" value="curl">
|
||||
|
||||
```bash showLineNumbers title="Example using LiteLLM Proxy Server"
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
|
||||
-H 'content-type: application/json' \
|
||||
-H 'x-api-key: $LITELLM_API_KEY' \
|
||||
-H 'anthropic-version: 2023-06-01' \
|
||||
-d '{
|
||||
"model": "anthropic-claude",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, can you tell me a short joke?"
|
||||
}
|
||||
],
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
Binary file not shown.
Binary file not shown.
1476
.venv/lib/python3.10/site-packages/litellm/assistants/main.py
Normal file
1476
.venv/lib/python3.10/site-packages/litellm/assistants/main.py
Normal file
File diff suppressed because it is too large
Load Diff
161
.venv/lib/python3.10/site-packages/litellm/assistants/utils.py
Normal file
161
.venv/lib/python3.10/site-packages/litellm/assistants/utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import litellm
|
||||
|
||||
from ..exceptions import UnsupportedParamsError
|
||||
from ..types.llms.openai import *
|
||||
|
||||
|
||||
def get_optional_params_add_message(
|
||||
role: Optional[str],
|
||||
content: Optional[
|
||||
Union[
|
||||
str,
|
||||
List[
|
||||
Union[
|
||||
MessageContentTextObject,
|
||||
MessageContentImageFileObject,
|
||||
MessageContentImageURLObject,
|
||||
]
|
||||
],
|
||||
]
|
||||
],
|
||||
attachments: Optional[List[Attachment]],
|
||||
metadata: Optional[dict],
|
||||
custom_llm_provider: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Azure doesn't support 'attachments' for creating a message
|
||||
|
||||
Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
|
||||
"""
|
||||
passed_params = locals()
|
||||
custom_llm_provider = passed_params.pop("custom_llm_provider")
|
||||
special_params = passed_params.pop("kwargs")
|
||||
for k, v in special_params.items():
|
||||
passed_params[k] = v
|
||||
|
||||
default_params = {
|
||||
"role": None,
|
||||
"content": None,
|
||||
"attachments": None,
|
||||
"metadata": None,
|
||||
}
|
||||
|
||||
non_default_params = {
|
||||
k: v
|
||||
for k, v in passed_params.items()
|
||||
if (k in default_params and v != default_params[k])
|
||||
}
|
||||
optional_params = {}
|
||||
|
||||
## raise exception if non-default value passed for non-openai/azure embedding calls
|
||||
def _check_valid_arg(supported_params):
|
||||
if len(non_default_params.keys()) > 0:
|
||||
keys = list(non_default_params.keys())
|
||||
for k in keys:
|
||||
if (
|
||||
litellm.drop_params is True and k not in supported_params
|
||||
): # drop the unsupported non-default values
|
||||
non_default_params.pop(k, None)
|
||||
elif k not in supported_params:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
status_code=500,
|
||||
message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
|
||||
k, custom_llm_provider, supported_params
|
||||
),
|
||||
)
|
||||
return non_default_params
|
||||
|
||||
if custom_llm_provider == "openai":
|
||||
optional_params = non_default_params
|
||||
elif custom_llm_provider == "azure":
|
||||
supported_params = (
|
||||
litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
|
||||
non_default_params=non_default_params, optional_params=optional_params
|
||||
)
|
||||
for k in passed_params.keys():
|
||||
if k not in default_params.keys():
|
||||
optional_params[k] = passed_params[k]
|
||||
return optional_params
|
||||
|
||||
|
||||
def get_optional_params_image_gen(
|
||||
n: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
response_format: Optional[str] = None,
|
||||
size: Optional[str] = None,
|
||||
style: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# retrieve all parameters passed to the function
|
||||
passed_params = locals()
|
||||
custom_llm_provider = passed_params.pop("custom_llm_provider")
|
||||
special_params = passed_params.pop("kwargs")
|
||||
for k, v in special_params.items():
|
||||
passed_params[k] = v
|
||||
|
||||
default_params = {
|
||||
"n": None,
|
||||
"quality": None,
|
||||
"response_format": None,
|
||||
"size": None,
|
||||
"style": None,
|
||||
"user": None,
|
||||
}
|
||||
|
||||
non_default_params = {
|
||||
k: v
|
||||
for k, v in passed_params.items()
|
||||
if (k in default_params and v != default_params[k])
|
||||
}
|
||||
optional_params = {}
|
||||
|
||||
## raise exception if non-default value passed for non-openai/azure embedding calls
|
||||
def _check_valid_arg(supported_params):
|
||||
if len(non_default_params.keys()) > 0:
|
||||
keys = list(non_default_params.keys())
|
||||
for k in keys:
|
||||
if (
|
||||
litellm.drop_params is True and k not in supported_params
|
||||
): # drop the unsupported non-default values
|
||||
non_default_params.pop(k, None)
|
||||
elif k not in supported_params:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=500,
|
||||
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
|
||||
)
|
||||
return non_default_params
|
||||
|
||||
if (
|
||||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
):
|
||||
optional_params = non_default_params
|
||||
elif custom_llm_provider == "bedrock":
|
||||
supported_params = ["size"]
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
if size is not None:
|
||||
width, height = size.split("x")
|
||||
optional_params["width"] = int(width)
|
||||
optional_params["height"] = int(height)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
supported_params = ["n"]
|
||||
"""
|
||||
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
|
||||
"""
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
if n is not None:
|
||||
optional_params["sampleCount"] = int(n)
|
||||
|
||||
for k in passed_params.keys():
|
||||
if k not in default_params.keys():
|
||||
optional_params[k] = passed_params[k]
|
||||
return optional_params
|
||||
@@ -0,0 +1,11 @@
|
||||
# Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`
|
||||
|
||||
Doc: https://docs.litellm.ai/docs/completion/batching
|
||||
|
||||
|
||||
LiteLLM Python SDK allows you to:
|
||||
1. `litellm.batch_completion` Batch litellm.completion function for a given model.
|
||||
2. `litellm.batch_completion_models` Send a request to multiple language models concurrently and return the response
|
||||
as soon as one of the models responds.
|
||||
3. `litellm.batch_completion_models_all_responses` Send a request to multiple language models concurrently and return a list of responses
|
||||
from all models that respond.
|
||||
Binary file not shown.
@@ -0,0 +1,253 @@
|
||||
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.utils import get_optional_params
|
||||
|
||||
from ..llms.vllm.completion import handler as vllm_handler
|
||||
|
||||
|
||||
def batch_completion(
|
||||
model: str,
|
||||
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: List = [],
|
||||
functions: Optional[List] = None,
|
||||
function_call: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stop=None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
user: Optional[str] = None,
|
||||
deployment_id=None,
|
||||
request_timeout: Optional[int] = None,
|
||||
timeout: Optional[int] = 600,
|
||||
max_workers: Optional[int] = 100,
|
||||
# Optional liteLLM function params
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Batch litellm.completion function for a given model.
|
||||
|
||||
Args:
|
||||
model (str): The model to use for generating completions.
|
||||
messages (List, optional): List of messages to use as input for generating completions. Defaults to [].
|
||||
functions (List, optional): List of functions to use as input for generating completions. Defaults to [].
|
||||
function_call (str, optional): The function call to use as input for generating completions. Defaults to "".
|
||||
temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
|
||||
top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
|
||||
n (int, optional): The number of completions to generate. Defaults to None.
|
||||
stream (bool, optional): Whether to stream completions or not. Defaults to None.
|
||||
stop (optional): The stop parameter for generating completions. Defaults to None.
|
||||
max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None.
|
||||
presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
|
||||
frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
|
||||
logit_bias (dict, optional): The logit bias for generating completions. Defaults to {}.
|
||||
user (str, optional): The user string for generating completions. Defaults to "".
|
||||
deployment_id (optional): The deployment ID for generating completions. Defaults to None.
|
||||
request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
|
||||
max_workers (int,optional): The maximum number of threads to use for parallel processing.
|
||||
|
||||
Returns:
|
||||
list: A list of completion results.
|
||||
"""
|
||||
args = locals()
|
||||
|
||||
batch_messages = messages
|
||||
completions = []
|
||||
model = model
|
||||
custom_llm_provider = None
|
||||
if model.split("/", 1)[0] in litellm.provider_list:
|
||||
custom_llm_provider = model.split("/", 1)[0]
|
||||
model = model.split("/", 1)[1]
|
||||
if custom_llm_provider == "vllm":
|
||||
optional_params = get_optional_params(
|
||||
functions=functions,
|
||||
function_call=function_call,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stream=stream or False,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
user=user,
|
||||
# params to identify the model
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
results = vllm_handler.batch_completions(
|
||||
model=model,
|
||||
messages=batch_messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
# all non VLLM models for batch completion models
|
||||
else:
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
for sub_batch in chunks(batch_messages, 100):
|
||||
for message_list in sub_batch:
|
||||
kwargs_modified = args.copy()
|
||||
kwargs_modified.pop("max_workers")
|
||||
kwargs_modified["messages"] = message_list
|
||||
original_kwargs = {}
|
||||
if "kwargs" in kwargs_modified:
|
||||
original_kwargs = kwargs_modified.pop("kwargs")
|
||||
future = executor.submit(
|
||||
litellm.completion, **kwargs_modified, **original_kwargs
|
||||
)
|
||||
completions.append(future)
|
||||
|
||||
# Retrieve the results from the futures
|
||||
# results = [future.result() for future in completions]
|
||||
# return exceptions if any
|
||||
results = []
|
||||
for future in completions:
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as exc:
|
||||
results.append(exc)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# send one request to multiple models
|
||||
# return as soon as one of the llms responds
|
||||
def batch_completion_models(*args, **kwargs):
|
||||
"""
|
||||
Send a request to multiple language models concurrently and return the response
|
||||
as soon as one of the models responds.
|
||||
|
||||
Args:
|
||||
*args: Variable-length positional arguments passed to the completion function.
|
||||
**kwargs: Additional keyword arguments:
|
||||
- models (str or list of str): The language models to send requests to.
|
||||
- Other keyword arguments to be passed to the completion function.
|
||||
|
||||
Returns:
|
||||
str or None: The response from one of the language models, or None if no response is received.
|
||||
|
||||
Note:
|
||||
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
|
||||
It sends requests concurrently and returns the response from the first model that responds.
|
||||
"""
|
||||
|
||||
if "model" in kwargs:
|
||||
kwargs.pop("model")
|
||||
if "models" in kwargs:
|
||||
models = kwargs["models"]
|
||||
kwargs.pop("models")
|
||||
futures = {}
|
||||
with ThreadPoolExecutor(max_workers=len(models)) as executor:
|
||||
for model in models:
|
||||
futures[model] = executor.submit(
|
||||
litellm.completion, *args, model=model, **kwargs
|
||||
)
|
||||
|
||||
for model, future in sorted(
|
||||
futures.items(), key=lambda x: models.index(x[0])
|
||||
):
|
||||
if future.result() is not None:
|
||||
return future.result()
|
||||
elif "deployments" in kwargs:
|
||||
deployments = kwargs["deployments"]
|
||||
kwargs.pop("deployments")
|
||||
kwargs.pop("model_list")
|
||||
nested_kwargs = kwargs.pop("kwargs", {})
|
||||
futures = {}
|
||||
with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
|
||||
for deployment in deployments:
|
||||
for key in kwargs.keys():
|
||||
if (
|
||||
key not in deployment
|
||||
): # don't override deployment values e.g. model name, api base, etc.
|
||||
deployment[key] = kwargs[key]
|
||||
kwargs = {**deployment, **nested_kwargs}
|
||||
futures[deployment["model"]] = executor.submit(
|
||||
litellm.completion, **kwargs
|
||||
)
|
||||
|
||||
while futures:
|
||||
# wait for the first returned future
|
||||
print_verbose("\n\n waiting for next result\n\n")
|
||||
done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
|
||||
print_verbose(f"done list\n{done}")
|
||||
for future in done:
|
||||
try:
|
||||
result = future.result()
|
||||
return result
|
||||
except Exception:
|
||||
# if model 1 fails, continue with response from model 2, model3
|
||||
print_verbose(
|
||||
"\n\ngot an exception, ignoring, removing from futures"
|
||||
)
|
||||
print_verbose(futures)
|
||||
new_futures = {}
|
||||
for key, value in futures.items():
|
||||
if future == value:
|
||||
print_verbose(f"removing key{key}")
|
||||
continue
|
||||
else:
|
||||
new_futures[key] = value
|
||||
futures = new_futures
|
||||
print_verbose(f"new futures{futures}")
|
||||
continue
|
||||
|
||||
print_verbose("\n\ndone looping through futures\n\n")
|
||||
print_verbose(futures)
|
||||
|
||||
return None # If no response is received from any model
|
||||
|
||||
|
||||
def batch_completion_models_all_responses(*args, **kwargs):
|
||||
"""
|
||||
Send a request to multiple language models concurrently and return a list of responses
|
||||
from all models that respond.
|
||||
|
||||
Args:
|
||||
*args: Variable-length positional arguments passed to the completion function.
|
||||
**kwargs: Additional keyword arguments:
|
||||
- models (str or list of str): The language models to send requests to.
|
||||
- Other keyword arguments to be passed to the completion function.
|
||||
|
||||
Returns:
|
||||
list: A list of responses from the language models that responded.
|
||||
|
||||
Note:
|
||||
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
|
||||
It sends requests concurrently and collects responses from all models that respond.
|
||||
"""
|
||||
import concurrent.futures
|
||||
|
||||
# ANSI escape codes for colored output
|
||||
|
||||
if "model" in kwargs:
|
||||
kwargs.pop("model")
|
||||
if "models" in kwargs:
|
||||
models = kwargs["models"]
|
||||
kwargs.pop("models")
|
||||
else:
|
||||
raise Exception("'models' param not in kwargs")
|
||||
|
||||
responses = []
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
|
||||
for idx, model in enumerate(models):
|
||||
future = executor.submit(litellm.completion, *args, model=model, **kwargs)
|
||||
if future.result() is not None:
|
||||
responses.append(future.result())
|
||||
|
||||
return responses
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,182 @@
|
||||
import json
|
||||
from typing import Any, List, Literal, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import Batch
|
||||
from litellm.types.utils import CallTypes, Usage
|
||||
|
||||
|
||||
async def _handle_completed_batch(
|
||||
batch: Batch,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
|
||||
) -> Tuple[float, Usage, List[str]]:
|
||||
"""Helper function to process a completed batch and handle logging"""
|
||||
# Get batch results
|
||||
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
|
||||
batch, custom_llm_provider
|
||||
)
|
||||
|
||||
# Calculate costs and usage
|
||||
batch_cost = await _batch_cost_calculator(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
file_content_dictionary=file_content_dictionary,
|
||||
)
|
||||
batch_usage = _get_batch_job_total_usage_from_file_content(
|
||||
file_content_dictionary=file_content_dictionary,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
batch_models = _get_batch_models_from_file_content(file_content_dictionary)
|
||||
|
||||
return batch_cost, batch_usage, batch_models
|
||||
|
||||
|
||||
def _get_batch_models_from_file_content(
|
||||
file_content_dictionary: List[dict],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get the models from the file content
|
||||
"""
|
||||
batch_models = []
|
||||
for _item in file_content_dictionary:
|
||||
if _batch_response_was_successful(_item):
|
||||
_response_body = _get_response_from_batch_job_output_file(_item)
|
||||
_model = _response_body.get("model")
|
||||
if _model:
|
||||
batch_models.append(_model)
|
||||
return batch_models
|
||||
|
||||
|
||||
async def _batch_cost_calculator(
|
||||
file_content_dictionary: List[dict],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the cost of a batch based on the output file id
|
||||
"""
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
raise ValueError("Vertex AI does not support file content retrieval")
|
||||
total_cost = _get_batch_job_cost_from_file_content(
|
||||
file_content_dictionary=file_content_dictionary,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
verbose_logger.debug("total_cost=%s", total_cost)
|
||||
return total_cost
|
||||
|
||||
|
||||
async def _get_batch_output_file_content_as_dictionary(
|
||||
batch: Batch,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get the batch output file content as a list of dictionaries
|
||||
"""
|
||||
from litellm.files.main import afile_content
|
||||
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
raise ValueError("Vertex AI does not support file content retrieval")
|
||||
|
||||
if batch.output_file_id is None:
|
||||
raise ValueError("Output file id is None cannot retrieve file content")
|
||||
|
||||
_file_content = await afile_content(
|
||||
file_id=batch.output_file_id,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
return _get_file_content_as_dictionary(_file_content.content)
|
||||
|
||||
|
||||
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
|
||||
"""
|
||||
Get the file content as a list of dictionaries from JSON Lines format
|
||||
"""
|
||||
try:
|
||||
_file_content_str = file_content.decode("utf-8")
|
||||
# Split by newlines and parse each line as a separate JSON object
|
||||
json_objects = []
|
||||
for line in _file_content_str.strip().split("\n"):
|
||||
if line: # Skip empty lines
|
||||
json_objects.append(json.loads(line))
|
||||
verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
|
||||
return json_objects
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _get_batch_job_cost_from_file_content(
|
||||
file_content_dictionary: List[dict],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> float:
|
||||
"""
|
||||
Get the cost of a batch job from the file content
|
||||
"""
|
||||
try:
|
||||
total_cost: float = 0.0
|
||||
# parse the file content as json
|
||||
verbose_logger.debug(
|
||||
"file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
|
||||
)
|
||||
for _item in file_content_dictionary:
|
||||
if _batch_response_was_successful(_item):
|
||||
_response_body = _get_response_from_batch_job_output_file(_item)
|
||||
total_cost += litellm.completion_cost(
|
||||
completion_response=_response_body,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
call_type=CallTypes.aretrieve_batch.value,
|
||||
)
|
||||
verbose_logger.debug("total_cost=%s", total_cost)
|
||||
return total_cost
|
||||
except Exception as e:
|
||||
verbose_logger.error("error in _get_batch_job_cost_from_file_content", e)
|
||||
raise e
|
||||
|
||||
|
||||
def _get_batch_job_total_usage_from_file_content(
|
||||
file_content_dictionary: List[dict],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> Usage:
|
||||
"""
|
||||
Get the tokens of a batch job from the file content
|
||||
"""
|
||||
total_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
for _item in file_content_dictionary:
|
||||
if _batch_response_was_successful(_item):
|
||||
_response_body = _get_response_from_batch_job_output_file(_item)
|
||||
usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
|
||||
total_tokens += usage.total_tokens
|
||||
prompt_tokens += usage.prompt_tokens
|
||||
completion_tokens += usage.completion_tokens
|
||||
return Usage(
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
|
||||
"""
|
||||
Get the tokens of a batch job from the response body
|
||||
"""
|
||||
_usage_dict = response_body.get("usage", None) or {}
|
||||
usage: Usage = Usage(**_usage_dict)
|
||||
return usage
|
||||
|
||||
|
||||
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
|
||||
"""
|
||||
Get the response from the batch job output file
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
_response_body = _response.get("body", None) or {}
|
||||
return _response_body
|
||||
|
||||
|
||||
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
|
||||
"""
|
||||
Check if the batch job response status == 200
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
return _response.get("status_code", None) == 200
|
||||
792
.venv/lib/python3.10/site-packages/litellm/batches/main.py
Normal file
792
.venv/lib/python3.10/site-packages/litellm/batches/main.py
Normal file
@@ -0,0 +1,792 @@
|
||||
"""
|
||||
Main File for Batches API implementation
|
||||
|
||||
https://platform.openai.com/docs/api-reference/batch
|
||||
|
||||
- create_batch()
|
||||
- retrieve_batch()
|
||||
- cancel_batch()
|
||||
- list_batch()
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.azure.batches.handler import AzureBatchesAPI
|
||||
from litellm.llms.openai.openai import OpenAIBatchesAPI
|
||||
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
Batch,
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_batches_instance = OpenAIBatchesAPI()
|
||||
azure_batches_instance = AzureBatchesAPI()
|
||||
vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
|
||||
#################################################
|
||||
|
||||
|
||||
@client
|
||||
async def acreate_batch(
|
||||
completion_window: Literal["24h"],
|
||||
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
||||
input_file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Batch:
|
||||
"""
|
||||
Async: Creates and executes a batch from an uploaded file of request
|
||||
|
||||
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["acreate_batch"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
create_batch,
|
||||
completion_window,
|
||||
endpoint,
|
||||
input_file_id,
|
||||
custom_llm_provider,
|
||||
metadata,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@client
|
||||
def create_batch(
|
||||
completion_window: Literal["24h"],
|
||||
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
||||
input_file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
"""
|
||||
Creates and executes a batch from an uploaded file of request
|
||||
|
||||
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
_is_async = kwargs.pop("acreate_batch", False) is True
|
||||
litellm_params = get_litellm_params(**kwargs)
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
user=None,
|
||||
optional_params=optional_params.model_dump(),
|
||||
litellm_params={
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"proxy_server_request": proxy_server_request,
|
||||
"model_info": model_info,
|
||||
"metadata": metadata,
|
||||
"preset_cache_key": None,
|
||||
"stream_response": {},
|
||||
**optional_params.model_dump(exclude_unset=True),
|
||||
},
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_create_batch_request = CreateBatchRequest(
|
||||
completion_window=completion_window,
|
||||
endpoint=endpoint,
|
||||
input_file_id=input_file_id,
|
||||
metadata=metadata,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
api_base: Optional[str] = None
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_batches_instance.create_batch(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
create_batch_data=_create_batch_request,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("AZURE_API_BASE")
|
||||
)
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_batches_instance.create_batch(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_batch_data=_create_batch_request,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
|
||||
response = vertex_ai_batches_instance.create_batch(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_batch_data=_create_batch_request,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support custom_llm_provider={} for 'create_batch'".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@client
|
||||
async def aretrieve_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> LiteLLMBatch:
|
||||
"""
|
||||
Async: Retrieves a batch.
|
||||
|
||||
LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["aretrieve_batch"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
retrieve_batch,
|
||||
batch_id,
|
||||
custom_llm_provider,
|
||||
metadata,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@client
|
||||
def retrieve_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
||||
"""
|
||||
Retrieves a batch.
|
||||
|
||||
LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
user=None,
|
||||
optional_params=optional_params.model_dump(),
|
||||
litellm_params=litellm_params,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_retrieve_batch_request = RetrieveBatchRequest(
|
||||
batch_id=batch_id,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
_is_async = kwargs.pop("aretrieve_batch", False) is True
|
||||
api_base: Optional[str] = None
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_batches_instance.retrieve_batch(
|
||||
_is_async=_is_async,
|
||||
retrieve_batch_data=_retrieve_batch_request,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("AZURE_API_BASE")
|
||||
)
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_batches_instance.retrieve_batch(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
retrieve_batch_data=_retrieve_batch_request,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
|
||||
response = vertex_ai_batches_instance.retrieve_batch(
|
||||
_is_async=_is_async,
|
||||
batch_id=batch_id,
|
||||
api_base=api_base,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
async def alist_batches(
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Async: List your organization's batches.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["alist_batches"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
list_batches,
|
||||
after,
|
||||
limit,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def list_batches(
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Lists batches
|
||||
|
||||
List your organization's batches.
|
||||
"""
|
||||
try:
|
||||
# set API KEY
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("alist_batches", False) is True
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
|
||||
response = openai_batches_instance.list_batches(
|
||||
_is_async=_is_async,
|
||||
after=after,
|
||||
limit=limit,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_batches_instance.list_batches(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'list_batch'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
async def acancel_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Batch:
|
||||
"""
|
||||
Async: Cancels a batch.
|
||||
|
||||
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["acancel_batch"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
cancel_batch,
|
||||
batch_id,
|
||||
custom_llm_provider,
|
||||
metadata,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def cancel_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
||||
"""
|
||||
Cancels a batch.
|
||||
|
||||
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_cancel_batch_request = CancelBatchRequest(
|
||||
batch_id=batch_id,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
_is_async = kwargs.pop("acancel_batch", False) is True
|
||||
api_base: Optional[str] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None
|
||||
)
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_batches_instance.cancel_batch(
|
||||
_is_async=_is_async,
|
||||
cancel_batch_data=_cancel_batch_request,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("AZURE_API_BASE")
|
||||
)
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_batches_instance.cancel_batch(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
cancel_batch_data=_cancel_batch_request,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="cancel_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
230
.venv/lib/python3.10/site-packages/litellm/budget_manager.py
Normal file
230
.venv/lib/python3.10/site-packages/litellm/budget_manager.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | NOT PROXY BUDGET MANAGER |
|
||||
# | proxy budget manager is in proxy_server.py |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.constants import (
|
||||
DAYS_IN_A_MONTH,
|
||||
DAYS_IN_A_WEEK,
|
||||
DAYS_IN_A_YEAR,
|
||||
HOURS_IN_A_DAY,
|
||||
)
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
|
||||
class BudgetManager:
|
||||
def __init__(
|
||||
self,
|
||||
project_name: str,
|
||||
client_type: str = "local",
|
||||
api_base: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
self.client_type = client_type
|
||||
self.project_name = project_name
|
||||
self.api_base = api_base or "https://api.litellm.ai"
|
||||
self.headers = headers or {"Content-Type": "application/json"}
|
||||
## load the data or init the initial dictionaries
|
||||
self.load_data()
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
try:
|
||||
if litellm.set_verbose:
|
||||
import logging
|
||||
|
||||
logging.info(print_statement)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def load_data(self):
|
||||
if self.client_type == "local":
|
||||
# Check if user dict file exists
|
||||
if os.path.isfile("user_cost.json"):
|
||||
# Load the user dict
|
||||
with open("user_cost.json", "r") as json_file:
|
||||
self.user_dict = json.load(json_file)
|
||||
else:
|
||||
self.print_verbose("User Dictionary not found!")
|
||||
self.user_dict = {}
|
||||
self.print_verbose(f"user dict from local: {self.user_dict}")
|
||||
elif self.client_type == "hosted":
|
||||
# Load the user_dict from hosted db
|
||||
url = self.api_base + "/get_budget"
|
||||
data = {"project_name": self.project_name}
|
||||
response = litellm.module_level_client.post(
|
||||
url, headers=self.headers, json=data
|
||||
)
|
||||
response = response.json()
|
||||
if response["status"] == "error":
|
||||
self.user_dict = (
|
||||
{}
|
||||
) # assume this means the user dict hasn't been stored yet
|
||||
else:
|
||||
self.user_dict = response["data"]
|
||||
|
||||
def create_budget(
|
||||
self,
|
||||
total_budget: float,
|
||||
user: str,
|
||||
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
|
||||
created_at: float = time.time(),
|
||||
):
|
||||
self.user_dict[user] = {"total_budget": total_budget}
|
||||
if duration is None:
|
||||
return self.user_dict[user]
|
||||
|
||||
if duration == "daily":
|
||||
duration_in_days = 1
|
||||
elif duration == "weekly":
|
||||
duration_in_days = DAYS_IN_A_WEEK
|
||||
elif duration == "monthly":
|
||||
duration_in_days = DAYS_IN_A_MONTH
|
||||
elif duration == "yearly":
|
||||
duration_in_days = DAYS_IN_A_YEAR
|
||||
else:
|
||||
raise ValueError(
|
||||
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
|
||||
)
|
||||
self.user_dict[user] = {
|
||||
"total_budget": total_budget,
|
||||
"duration": duration_in_days,
|
||||
"created_at": created_at,
|
||||
"last_updated_at": created_at,
|
||||
}
|
||||
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
||||
return self.user_dict[user]
|
||||
|
||||
def projected_cost(self, model: str, messages: list, user: str):
|
||||
text = "".join(message["content"] for message in messages)
|
||||
prompt_tokens = litellm.token_counter(model=model, text=text)
|
||||
prompt_cost, _ = litellm.cost_per_token(
|
||||
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
|
||||
)
|
||||
current_cost = self.user_dict[user].get("current_cost", 0)
|
||||
projected_cost = prompt_cost + current_cost
|
||||
return projected_cost
|
||||
|
||||
def get_total_budget(self, user: str):
|
||||
return self.user_dict[user]["total_budget"]
|
||||
|
||||
def update_cost(
|
||||
self,
|
||||
user: str,
|
||||
completion_obj: Optional[ModelResponse] = None,
|
||||
model: Optional[str] = None,
|
||||
input_text: Optional[str] = None,
|
||||
output_text: Optional[str] = None,
|
||||
):
|
||||
if model and input_text and output_text:
|
||||
prompt_tokens = litellm.token_counter(
|
||||
model=model, messages=[{"role": "user", "content": input_text}]
|
||||
)
|
||||
completion_tokens = litellm.token_counter(
|
||||
model=model, messages=[{"role": "user", "content": output_text}]
|
||||
)
|
||||
(
|
||||
prompt_tokens_cost_usd_dollar,
|
||||
completion_tokens_cost_usd_dollar,
|
||||
) = litellm.cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
||||
elif completion_obj:
|
||||
cost = litellm.completion_cost(completion_response=completion_obj)
|
||||
model = completion_obj[
|
||||
"model"
|
||||
] # if this throws an error try, model = completion_obj['model']
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
|
||||
)
|
||||
|
||||
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
|
||||
"current_cost", 0
|
||||
)
|
||||
if "model_cost" in self.user_dict[user]:
|
||||
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
|
||||
"model_cost"
|
||||
].get(model, 0)
|
||||
else:
|
||||
self.user_dict[user]["model_cost"] = {model: cost}
|
||||
|
||||
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
||||
return {"user": self.user_dict[user]}
|
||||
|
||||
def get_current_cost(self, user):
|
||||
return self.user_dict[user].get("current_cost", 0)
|
||||
|
||||
def get_model_cost(self, user):
|
||||
return self.user_dict[user].get("model_cost", 0)
|
||||
|
||||
def is_valid_user(self, user: str) -> bool:
|
||||
return user in self.user_dict
|
||||
|
||||
def get_users(self):
|
||||
return list(self.user_dict.keys())
|
||||
|
||||
def reset_cost(self, user):
|
||||
self.user_dict[user]["current_cost"] = 0
|
||||
self.user_dict[user]["model_cost"] = {}
|
||||
return {"user": self.user_dict[user]}
|
||||
|
||||
def reset_on_duration(self, user: str):
|
||||
# Get current and creation time
|
||||
last_updated_at = self.user_dict[user]["last_updated_at"]
|
||||
current_time = time.time()
|
||||
|
||||
# Convert duration from days to seconds
|
||||
duration_in_seconds = (
|
||||
self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
|
||||
)
|
||||
|
||||
# Check if duration has elapsed
|
||||
if current_time - last_updated_at >= duration_in_seconds:
|
||||
# Reset cost if duration has elapsed and update the creation time
|
||||
self.reset_cost(user)
|
||||
self.user_dict[user]["last_updated_at"] = current_time
|
||||
self._save_data_thread() # Save the data
|
||||
|
||||
def update_budget_all_users(self):
|
||||
for user in self.get_users():
|
||||
if "duration" in self.user_dict[user]:
|
||||
self.reset_on_duration(user)
|
||||
|
||||
def _save_data_thread(self):
|
||||
thread = threading.Thread(
|
||||
target=self.save_data
|
||||
) # [Non-Blocking]: saves data without blocking execution
|
||||
thread.start()
|
||||
|
||||
def save_data(self):
|
||||
if self.client_type == "local":
|
||||
import json
|
||||
|
||||
# save the user dict
|
||||
with open("user_cost.json", "w") as json_file:
|
||||
json.dump(
|
||||
self.user_dict, json_file, indent=4
|
||||
) # Indent for pretty formatting
|
||||
return {"status": "success"}
|
||||
elif self.client_type == "hosted":
|
||||
url = self.api_base + "/set_budget"
|
||||
data = {"project_name": self.project_name, "user_dict": self.user_dict}
|
||||
response = litellm.module_level_client.post(
|
||||
url, headers=self.headers, json=data
|
||||
)
|
||||
response = response.json()
|
||||
return response
|
||||
40
.venv/lib/python3.10/site-packages/litellm/caching/Readme.md
Normal file
40
.venv/lib/python3.10/site-packages/litellm/caching/Readme.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Caching on LiteLLM
|
||||
|
||||
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
|
||||
|
||||
The following caching mechanisms are supported:
|
||||
|
||||
1. **RedisCache**
|
||||
2. **RedisSemanticCache**
|
||||
3. **QdrantSemanticCache**
|
||||
4. **InMemoryCache**
|
||||
5. **DiskCache**
|
||||
6. **S3Cache**
|
||||
7. **DualCache** (updates both Redis and an in-memory cache simultaneously)
|
||||
|
||||
## Folder Structure
|
||||
|
||||
```
|
||||
litellm/caching/
|
||||
├── base_cache.py
|
||||
├── caching.py
|
||||
├── caching_handler.py
|
||||
├── disk_cache.py
|
||||
├── dual_cache.py
|
||||
├── in_memory_cache.py
|
||||
├── qdrant_semantic_cache.py
|
||||
├── redis_cache.py
|
||||
├── redis_semantic_cache.py
|
||||
├── s3_cache.py
|
||||
```
|
||||
|
||||
## Documentation
|
||||
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
|
||||
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
from .caching import Cache, LiteLLMCacheType
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,30 @@
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Optional, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def lru_cache_wrapper(
|
||||
maxsize: Optional[int] = None,
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
"""
|
||||
Wrapper for lru_cache that caches success and exceptions
|
||||
"""
|
||||
|
||||
def decorator(f: Callable[..., T]) -> Callable[..., T]:
|
||||
@lru_cache(maxsize=maxsize)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return ("success", f(*args, **kwargs))
|
||||
except Exception as e:
|
||||
return ("error", e)
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
result = wrapper(*args, **kwargs)
|
||||
if result[0] == "error":
|
||||
raise result[1]
|
||||
return result[1]
|
||||
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Base Cache implementation. All cache implementations should inherit from this class.
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class BaseCache(ABC):
|
||||
def __init__(self, default_ttl: int = 60):
|
||||
self.default_ttl = default_ttl
|
||||
|
||||
def get_ttl(self, **kwargs) -> Optional[int]:
|
||||
kwargs_ttl: Optional[int] = kwargs.get("ttl")
|
||||
if kwargs_ttl is not None:
|
||||
try:
|
||||
return int(kwargs_ttl)
|
||||
except ValueError:
|
||||
return self.default_ttl
|
||||
return self.default_ttl
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
pass
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def batch_cache_write(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def disconnect(self):
|
||||
raise NotImplementedError
|
||||
798
.venv/lib/python3.10/site-packages/litellm/caching/caching.py
Normal file
798
.venv/lib/python3.10/site-packages/litellm/caching/caching.py
Normal file
@@ -0,0 +1,798 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
|
||||
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
||||
from litellm.types.caching import *
|
||||
from litellm.types.utils import all_litellm_params
|
||||
|
||||
from .base_cache import BaseCache
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache # noqa
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
verbose_logger.debug(print_statement)
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class CacheMode(str, Enum):
|
||||
default_on = "default_on"
|
||||
default_off = "default_off"
|
||||
|
||||
|
||||
#### LiteLLM.Completion / Embedding Cache ####
|
||||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
mode: Optional[
|
||||
CacheMode
|
||||
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
namespace: Optional[str] = None,
|
||||
ttl: Optional[float] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
default_in_redis_ttl: Optional[float] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
],
|
||||
# s3 Bucket, boto3 configuration
|
||||
s3_bucket_name: Optional[str] = None,
|
||||
s3_region_name: Optional[str] = None,
|
||||
s3_api_version: Optional[str] = None,
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify: Optional[Union[bool, str]] = None,
|
||||
s3_endpoint_url: Optional[str] = None,
|
||||
s3_aws_access_key_id: Optional[str] = None,
|
||||
s3_aws_secret_access_key: Optional[str] = None,
|
||||
s3_aws_session_token: Optional[str] = None,
|
||||
s3_config: Optional[Any] = None,
|
||||
s3_path: Optional[str] = None,
|
||||
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
redis_semantic_cache_index_name: Optional[str] = None,
|
||||
redis_flush_size: Optional[int] = None,
|
||||
redis_startup_nodes: Optional[List] = None,
|
||||
disk_cache_dir: Optional[str] = None,
|
||||
qdrant_api_base: Optional[str] = None,
|
||||
qdrant_api_key: Optional[str] = None,
|
||||
qdrant_collection_name: Optional[str] = None,
|
||||
qdrant_quantization_config: Optional[str] = None,
|
||||
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||
|
||||
# Redis Cache Args
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
|
||||
ttl (float, optional): The ttl for the Redis cache
|
||||
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
|
||||
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
|
||||
|
||||
# Qdrant Cache Args
|
||||
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||
|
||||
# Disk Cache Args
|
||||
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
|
||||
|
||||
# S3 Cache Args
|
||||
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
|
||||
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
|
||||
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
|
||||
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
|
||||
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
|
||||
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
|
||||
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
|
||||
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
|
||||
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
|
||||
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
|
||||
|
||||
# Common Cache Args
|
||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid cache type is provided.
|
||||
|
||||
Returns:
|
||||
None. Cache is set as a litellm param
|
||||
"""
|
||||
if type == LiteLLMCacheType.REDIS:
|
||||
if redis_startup_nodes:
|
||||
self.cache: BaseCache = RedisClusterCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
startup_nodes=redis_startup_nodes,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.cache = RedisCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
||||
self.cache = RedisSemanticCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
similarity_threshold=similarity_threshold,
|
||||
embedding_model=redis_semantic_cache_embedding_model,
|
||||
index_name=redis_semantic_cache_index_name,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
|
||||
self.cache = QdrantSemanticCache(
|
||||
qdrant_api_base=qdrant_api_base,
|
||||
qdrant_api_key=qdrant_api_key,
|
||||
collection_name=qdrant_collection_name,
|
||||
similarity_threshold=similarity_threshold,
|
||||
quantization_config=qdrant_quantization_config,
|
||||
embedding_model=qdrant_semantic_cache_embedding_model,
|
||||
)
|
||||
elif type == LiteLLMCacheType.LOCAL:
|
||||
self.cache = InMemoryCache()
|
||||
elif type == LiteLLMCacheType.S3:
|
||||
self.cache = S3Cache(
|
||||
s3_bucket_name=s3_bucket_name,
|
||||
s3_region_name=s3_region_name,
|
||||
s3_api_version=s3_api_version,
|
||||
s3_use_ssl=s3_use_ssl,
|
||||
s3_verify=s3_verify,
|
||||
s3_endpoint_url=s3_endpoint_url,
|
||||
s3_aws_access_key_id=s3_aws_access_key_id,
|
||||
s3_aws_secret_access_key=s3_aws_secret_access_key,
|
||||
s3_aws_session_token=s3_aws_session_token,
|
||||
s3_config=s3_config,
|
||||
s3_path=s3_path,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.DISK:
|
||||
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
||||
self.type = type
|
||||
self.namespace = namespace
|
||||
self.redis_flush_size = redis_flush_size
|
||||
self.ttl = ttl
|
||||
self.mode: CacheMode = mode or CacheMode.default_on
|
||||
|
||||
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
|
||||
self.ttl = default_in_memory_ttl
|
||||
|
||||
if (
|
||||
self.type == LiteLLMCacheType.REDIS
|
||||
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
|
||||
) and default_in_redis_ttl is not None:
|
||||
self.ttl = default_in_redis_ttl
|
||||
|
||||
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
||||
self.cache.namespace = self.namespace
|
||||
|
||||
def get_cache_key(self, **kwargs) -> str:
|
||||
"""
|
||||
Get the cache key for the given arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: kwargs to litellm.completion() or embedding()
|
||||
|
||||
Returns:
|
||||
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
||||
"""
|
||||
cache_key = ""
|
||||
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
|
||||
|
||||
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
|
||||
if preset_cache_key is not None:
|
||||
verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
|
||||
return preset_cache_key
|
||||
|
||||
combined_kwargs = ModelParamHelper._get_all_llm_api_params()
|
||||
litellm_param_kwargs = all_litellm_params
|
||||
for param in kwargs:
|
||||
if param in combined_kwargs:
|
||||
param_value: Optional[str] = self._get_param_value(param, kwargs)
|
||||
if param_value is not None:
|
||||
cache_key += f"{str(param)}: {str(param_value)}"
|
||||
elif (
|
||||
param not in litellm_param_kwargs
|
||||
): # check if user passed in optional param - e.g. top_k
|
||||
if (
|
||||
litellm.enable_caching_on_provider_specific_optional_params is True
|
||||
): # feature flagged for now
|
||||
if kwargs[param] is None:
|
||||
continue # ignore None params
|
||||
param_value = kwargs[param]
|
||||
cache_key += f"{str(param)}: {str(param_value)}"
|
||||
|
||||
verbose_logger.debug("\nCreated cache key: %s", cache_key)
|
||||
hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
|
||||
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
|
||||
self._set_preset_cache_key_in_kwargs(
|
||||
preset_cache_key=hashed_cache_key, **kwargs
|
||||
)
|
||||
return hashed_cache_key
|
||||
|
||||
def _get_param_value(
|
||||
self,
|
||||
param: str,
|
||||
kwargs: dict,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the value for the given param from kwargs
|
||||
"""
|
||||
if param == "model":
|
||||
return self._get_model_param_value(kwargs)
|
||||
elif param == "file":
|
||||
return self._get_file_param_value(kwargs)
|
||||
return kwargs[param]
|
||||
|
||||
def _get_model_param_value(self, kwargs: dict) -> str:
|
||||
"""
|
||||
Handles getting the value for the 'model' param from kwargs
|
||||
|
||||
1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
|
||||
2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
|
||||
3. Else use the `model` passed in kwargs
|
||||
"""
|
||||
metadata: Dict = kwargs.get("metadata", {}) or {}
|
||||
litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
|
||||
metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
|
||||
model_group: Optional[str] = metadata.get(
|
||||
"model_group"
|
||||
) or metadata_in_litellm_params.get("model_group")
|
||||
caching_group = self._get_caching_group(metadata, model_group)
|
||||
return caching_group or model_group or kwargs["model"]
|
||||
|
||||
def _get_caching_group(
|
||||
self, metadata: dict, model_group: Optional[str]
|
||||
) -> Optional[str]:
|
||||
caching_groups: Optional[List] = metadata.get("caching_groups", [])
|
||||
if caching_groups:
|
||||
for group in caching_groups:
|
||||
if model_group in group:
|
||||
return str(group)
|
||||
return None
|
||||
|
||||
def _get_file_param_value(self, kwargs: dict) -> str:
|
||||
"""
|
||||
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
|
||||
"""
|
||||
file = kwargs.get("file")
|
||||
metadata = kwargs.get("metadata", {})
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
return (
|
||||
metadata.get("file_checksum")
|
||||
or getattr(file, "name", None)
|
||||
or metadata.get("file_name")
|
||||
or litellm_params.get("file_name")
|
||||
)
|
||||
|
||||
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Get the preset cache key from kwargs["litellm_params"]
|
||||
|
||||
We use _get_preset_cache_keys for two reasons
|
||||
|
||||
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
||||
2. avoid doing duplicate / repeated work
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
return kwargs["litellm_params"].get("preset_cache_key", None)
|
||||
return None
|
||||
|
||||
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
|
||||
"""
|
||||
Set the calculated cache key in kwargs
|
||||
|
||||
This is used to avoid doing duplicate / repeated work
|
||||
|
||||
Placed in kwargs["litellm_params"]
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
|
||||
|
||||
@staticmethod
|
||||
def _get_hashed_cache_key(cache_key: str) -> str:
|
||||
"""
|
||||
Get the hashed cache key for the given cache key.
|
||||
|
||||
Use hashlib to create a sha256 hash of the cache key
|
||||
|
||||
Args:
|
||||
cache_key (str): The cache key to hash.
|
||||
|
||||
Returns:
|
||||
str: The hashed cache key.
|
||||
"""
|
||||
hash_object = hashlib.sha256(cache_key.encode())
|
||||
# Hexadecimal representation of the hash
|
||||
hash_hex = hash_object.hexdigest()
|
||||
verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
|
||||
return hash_hex
|
||||
|
||||
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
|
||||
"""
|
||||
If a redis namespace is provided, add it to the cache key
|
||||
|
||||
Args:
|
||||
hash_hex (str): The hashed cache key.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The final hashed cache key with the redis namespace.
|
||||
"""
|
||||
dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
|
||||
namespace = (
|
||||
dynamic_cache_control.get("namespace")
|
||||
or kwargs.get("metadata", {}).get("redis_namespace")
|
||||
or self.namespace
|
||||
)
|
||||
if namespace:
|
||||
hash_hex = f"{namespace}:{hash_hex}"
|
||||
verbose_logger.debug("Final hashed key: %s", hash_hex)
|
||||
return hash_hex
|
||||
|
||||
def generate_streaming_content(self, content):
|
||||
chunk_size = 5 # Adjust the chunk size as needed
|
||||
for i in range(0, len(content), chunk_size):
|
||||
yield {
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"content": content[i : i + chunk_size],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
time.sleep(CACHED_STREAMING_CHUNK_DELAY)
|
||||
|
||||
def _get_cache_logic(
|
||||
self,
|
||||
cached_result: Optional[Any],
|
||||
max_age: Optional[float],
|
||||
):
|
||||
"""
|
||||
Common get cache logic across sync + async implementations
|
||||
"""
|
||||
# Check if a timestamp was stored with the cached response
|
||||
if (
|
||||
cached_result is not None
|
||||
and isinstance(cached_result, dict)
|
||||
and "timestamp" in cached_result
|
||||
):
|
||||
timestamp = cached_result["timestamp"]
|
||||
current_time = time.time()
|
||||
|
||||
# Calculate age of the cached response
|
||||
response_age = current_time - timestamp
|
||||
|
||||
# Check if the cached response is older than the max-age
|
||||
if max_age is not None and response_age > max_age:
|
||||
return None # Cached response is too old
|
||||
|
||||
# If the response is fresh, or there's no max-age requirement, return the cached response
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = cached_result.get("response")
|
||||
try:
|
||||
if isinstance(cached_response, dict):
|
||||
pass
|
||||
else:
|
||||
cached_response = json.loads(
|
||||
cached_response # type: ignore
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response) # type: ignore
|
||||
return cached_response
|
||||
return cached_result
|
||||
|
||||
def get_cache(self, **kwargs):
|
||||
"""
|
||||
Retrieves the cached result for the given arguments.
|
||||
|
||||
Args:
|
||||
*args: args to litellm.completion() or embedding()
|
||||
**kwargs: kwargs to litellm.completion() or embedding()
|
||||
|
||||
Returns:
|
||||
The cached result if it exists, otherwise None.
|
||||
"""
|
||||
try: # never block execution
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
messages = kwargs.get("messages", [])
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
cache_key = self.get_cache_key(**kwargs)
|
||||
if cache_key is not None:
|
||||
cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
|
||||
max_age = (
|
||||
cache_control_args.get("s-maxage")
|
||||
or cache_control_args.get("s-max-age")
|
||||
or float("inf")
|
||||
)
|
||||
cached_result = self.cache.get_cache(cache_key, messages=messages)
|
||||
cached_result = self.cache.get_cache(cache_key, messages=messages)
|
||||
return self._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception:
|
||||
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def async_get_cache(self, **kwargs):
|
||||
"""
|
||||
Async get cache implementation.
|
||||
|
||||
Used for embedding calls in async wrapper
|
||||
"""
|
||||
|
||||
try: # never block execution
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
|
||||
kwargs.get("messages", [])
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
cache_key = self.get_cache_key(**kwargs)
|
||||
if cache_key is not None:
|
||||
cache_control_args = kwargs.get("cache", {})
|
||||
max_age = cache_control_args.get(
|
||||
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
|
||||
)
|
||||
cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
|
||||
return self._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception:
|
||||
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
def _add_cache_logic(self, result, **kwargs):
|
||||
"""
|
||||
Common implementation across sync + async add_cache functions
|
||||
"""
|
||||
try:
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
cache_key = self.get_cache_key(**kwargs)
|
||||
if cache_key is not None:
|
||||
if isinstance(result, BaseModel):
|
||||
result = result.model_dump_json()
|
||||
|
||||
## DEFAULT TTL ##
|
||||
if self.ttl is not None:
|
||||
kwargs["ttl"] = self.ttl
|
||||
## Get Cache-Controls ##
|
||||
_cache_kwargs = kwargs.get("cache", None)
|
||||
if isinstance(_cache_kwargs, dict):
|
||||
for k, v in _cache_kwargs.items():
|
||||
if k == "ttl":
|
||||
kwargs["ttl"] = v
|
||||
|
||||
cached_data = {"timestamp": time.time(), "response": result}
|
||||
return cache_key, cached_data, kwargs
|
||||
else:
|
||||
raise Exception("cache key is None")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def add_cache(self, result, **kwargs):
|
||||
"""
|
||||
Adds a result to the cache.
|
||||
|
||||
Args:
|
||||
*args: args to litellm.completion() or embedding()
|
||||
**kwargs: kwargs to litellm.completion() or embedding()
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
result=result, **kwargs
|
||||
)
|
||||
self.cache.set_cache(cache_key, cached_data, **kwargs)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
||||
async def async_add_cache(self, result, **kwargs):
|
||||
"""
|
||||
Async implementation of add_cache
|
||||
"""
|
||||
try:
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
if self.type == "redis" and self.redis_flush_size is not None:
|
||||
# high traffic - fill in results in memory and then flush
|
||||
await self.batch_cache_write(result, **kwargs)
|
||||
else:
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
result=result, **kwargs
|
||||
)
|
||||
|
||||
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
||||
async def async_add_cache_pipeline(self, result, **kwargs):
|
||||
"""
|
||||
Async implementation of add_cache for Embedding calls
|
||||
|
||||
Does a bulk write, to prevent using too many clients
|
||||
"""
|
||||
try:
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
|
||||
# set default ttl if not set
|
||||
if self.ttl is not None:
|
||||
kwargs["ttl"] = self.ttl
|
||||
|
||||
cache_list = []
|
||||
for idx, i in enumerate(kwargs["input"]):
|
||||
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
|
||||
kwargs["cache_key"] = preset_cache_key
|
||||
embedding_response = result.data[idx]
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
result=embedding_response,
|
||||
**kwargs,
|
||||
)
|
||||
cache_list.append((cache_key, cached_data))
|
||||
|
||||
await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
|
||||
# if async_set_cache_pipeline:
|
||||
# await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
|
||||
# else:
|
||||
# tasks = []
|
||||
# for val in cache_list:
|
||||
# tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
|
||||
# await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
||||
def should_use_cache(self, **kwargs):
|
||||
"""
|
||||
Returns true if we should use the cache for LLM API calls
|
||||
|
||||
If cache is default_on then this is True
|
||||
If cache is default_off then this is only true when user has opted in to use cache
|
||||
"""
|
||||
if self.mode == CacheMode.default_on:
|
||||
return True
|
||||
|
||||
# when mode == default_off -> Cache is opt in only
|
||||
_cache = kwargs.get("cache", None)
|
||||
verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
|
||||
if _cache and isinstance(_cache, dict):
|
||||
if _cache.get("use-cache", False) is True:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def batch_cache_write(self, result, **kwargs):
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
|
||||
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
|
||||
|
||||
async def ping(self):
|
||||
cache_ping = getattr(self.cache, "ping")
|
||||
if cache_ping:
|
||||
return await cache_ping()
|
||||
return None
|
||||
|
||||
async def delete_cache_keys(self, keys):
|
||||
cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
|
||||
if cache_delete_cache_keys:
|
||||
return await cache_delete_cache_keys(keys)
|
||||
return None
|
||||
|
||||
async def disconnect(self):
|
||||
if hasattr(self.cache, "disconnect"):
|
||||
await self.cache.disconnect()
|
||||
|
||||
def _supports_async(self) -> bool:
|
||||
"""
|
||||
Internal method to check if the cache type supports async get/set operations
|
||||
|
||||
Only S3 Cache Does NOT support async operations
|
||||
|
||||
"""
|
||||
if self.type and self.type == LiteLLMCacheType.S3:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def enable_cache(
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Enable cache with the specified configuration.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
|
||||
host (Optional[str]): The host address of the cache server. Defaults to None.
|
||||
port (Optional[str]): The port number of the cache server. Defaults to None.
|
||||
password (Optional[str]): The password for the cache server. Defaults to None.
|
||||
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
|
||||
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
print_verbose("LiteLLM: Enabling Cache")
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
|
||||
if litellm.cache is None:
|
||||
litellm.cache = Cache(
|
||||
type=type,
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
supported_call_types=supported_call_types,
|
||||
**kwargs,
|
||||
)
|
||||
print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
|
||||
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def update_cache(
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Update the cache for LiteLLM.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
|
||||
host (Optional[str]): The host of the cache. Defaults to None.
|
||||
port (Optional[str]): The port of the cache. Defaults to None.
|
||||
password (Optional[str]): The password for the cache. Defaults to None.
|
||||
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
|
||||
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
|
||||
**kwargs: Additional keyword arguments for the cache.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
print_verbose("LiteLLM: Updating Cache")
|
||||
litellm.cache = Cache(
|
||||
type=type,
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
supported_call_types=supported_call_types,
|
||||
**kwargs,
|
||||
)
|
||||
print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
|
||||
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def disable_cache():
|
||||
"""
|
||||
Disable the cache used by LiteLLM.
|
||||
|
||||
This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from contextlib import suppress
|
||||
|
||||
print_verbose("LiteLLM: Disabling Cache")
|
||||
with suppress(ValueError):
|
||||
litellm.input_callback.remove("cache")
|
||||
litellm.success_callback.remove("cache")
|
||||
litellm._async_success_callback.remove("cache")
|
||||
|
||||
litellm.cache = None
|
||||
print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
|
||||
@@ -0,0 +1,906 @@
|
||||
"""
|
||||
This contains LLMCachingHandler
|
||||
|
||||
This exposes two methods:
|
||||
- async_get_cache
|
||||
- async_set_cache
|
||||
|
||||
This file is a wrapper around caching.py
|
||||
|
||||
This class is used to handle caching logic specific for LLM API requests (completion / embedding / text_completion / transcription etc)
|
||||
|
||||
It utilizes the (RedisCache, s3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has setup
|
||||
|
||||
In each method it will call the appropriate method from caching.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import inspect
|
||||
import threading
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.caching.caching import S3Cache
|
||||
from litellm.litellm_core_utils.logging_utils import (
|
||||
_assemble_complete_response_from_streaming_chunks,
|
||||
)
|
||||
from litellm.types.rerank import RerankResponse
|
||||
from litellm.types.utils import (
|
||||
CallTypes,
|
||||
Embedding,
|
||||
EmbeddingResponse,
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
TranscriptionResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
CustomStreamWrapper = Any
|
||||
|
||||
|
||||
class CachingHandlerResponse(BaseModel):
|
||||
"""
|
||||
This is the response object for the caching handler. We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses
|
||||
|
||||
For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others
|
||||
"""
|
||||
|
||||
cached_result: Optional[Any] = None
|
||||
final_embedding_cached_response: Optional[EmbeddingResponse] = None
|
||||
embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
|
||||
|
||||
|
||||
class LLMCachingHandler:
|
||||
def __init__(
|
||||
self,
|
||||
original_function: Callable,
|
||||
request_kwargs: Dict[str, Any],
|
||||
start_time: datetime.datetime,
|
||||
):
|
||||
self.async_streaming_chunks: List[ModelResponse] = []
|
||||
self.sync_streaming_chunks: List[ModelResponse] = []
|
||||
self.request_kwargs = request_kwargs
|
||||
self.original_function = original_function
|
||||
self.start_time = start_time
|
||||
pass
|
||||
|
||||
async def _async_get_cache(
|
||||
self,
|
||||
model: str,
|
||||
original_function: Callable,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
start_time: datetime.datetime,
|
||||
call_type: str,
|
||||
kwargs: Dict[str, Any],
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
) -> CachingHandlerResponse:
|
||||
"""
|
||||
Internal method to get from the cache.
|
||||
Handles different call types (embeddings, chat/completions, text_completion, transcription)
|
||||
and accordingly returns the cached response
|
||||
|
||||
Args:
|
||||
model: str:
|
||||
original_function: Callable:
|
||||
logging_obj: LiteLLMLoggingObj:
|
||||
start_time: datetime.datetime:
|
||||
call_type: str:
|
||||
kwargs: Dict[str, Any]:
|
||||
args: Optional[Tuple[Any, ...]] = None:
|
||||
|
||||
|
||||
Returns:
|
||||
CachingHandlerResponse:
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
args = args or ()
|
||||
|
||||
final_embedding_cached_response: Optional[EmbeddingResponse] = None
|
||||
embedding_all_elements_cache_hit: bool = False
|
||||
cached_result: Optional[Any] = None
|
||||
if (
|
||||
(kwargs.get("caching", None) is None and litellm.cache is not None)
|
||||
or kwargs.get("caching", False) is True
|
||||
) and (
|
||||
kwargs.get("cache", {}).get("no-cache", False) is not True
|
||||
): # allow users to control returning cached responses from the completion function
|
||||
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
||||
original_function=original_function
|
||||
):
|
||||
verbose_logger.debug("Checking Cache")
|
||||
cached_result = await self._retrieve_from_cache(
|
||||
call_type=call_type,
|
||||
kwargs=kwargs,
|
||||
args=args,
|
||||
)
|
||||
|
||||
if cached_result is not None and not isinstance(cached_result, list):
|
||||
verbose_logger.debug("Cache Hit!")
|
||||
cache_hit = True
|
||||
end_time = datetime.datetime.now()
|
||||
model, _, _, _ = litellm.get_llm_provider(
|
||||
model=model,
|
||||
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
||||
api_base=kwargs.get("api_base", None),
|
||||
api_key=kwargs.get("api_key", None),
|
||||
)
|
||||
self._update_litellm_logging_obj_environment(
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
kwargs=kwargs,
|
||||
cached_result=cached_result,
|
||||
is_async=True,
|
||||
)
|
||||
|
||||
call_type = original_function.__name__
|
||||
|
||||
cached_result = self._convert_cached_result_to_model_response(
|
||||
cached_result=cached_result,
|
||||
call_type=call_type,
|
||||
kwargs=kwargs,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
||||
args=args,
|
||||
)
|
||||
if kwargs.get("stream", False) is False:
|
||||
# LOG SUCCESS
|
||||
self._async_log_cache_hit_on_callbacks(
|
||||
logging_obj=logging_obj,
|
||||
cached_result=cached_result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
|
||||
**kwargs
|
||||
)
|
||||
if (
|
||||
isinstance(cached_result, BaseModel)
|
||||
or isinstance(cached_result, CustomStreamWrapper)
|
||||
) and hasattr(cached_result, "_hidden_params"):
|
||||
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
|
||||
return CachingHandlerResponse(cached_result=cached_result)
|
||||
elif (
|
||||
call_type == CallTypes.aembedding.value
|
||||
and cached_result is not None
|
||||
and isinstance(cached_result, list)
|
||||
and litellm.cache is not None
|
||||
and not isinstance(
|
||||
litellm.cache.cache, S3Cache
|
||||
) # s3 doesn't support bulk writing. Exclude.
|
||||
):
|
||||
(
|
||||
final_embedding_cached_response,
|
||||
embedding_all_elements_cache_hit,
|
||||
) = self._process_async_embedding_cached_response(
|
||||
final_embedding_cached_response=final_embedding_cached_response,
|
||||
cached_result=cached_result,
|
||||
kwargs=kwargs,
|
||||
logging_obj=logging_obj,
|
||||
start_time=start_time,
|
||||
model=model,
|
||||
)
|
||||
return CachingHandlerResponse(
|
||||
final_embedding_cached_response=final_embedding_cached_response,
|
||||
embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
|
||||
)
|
||||
verbose_logger.debug(f"CACHE RESULT: {cached_result}")
|
||||
return CachingHandlerResponse(
|
||||
cached_result=cached_result,
|
||||
final_embedding_cached_response=final_embedding_cached_response,
|
||||
)
|
||||
|
||||
def _sync_get_cache(
|
||||
self,
|
||||
model: str,
|
||||
original_function: Callable,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
start_time: datetime.datetime,
|
||||
call_type: str,
|
||||
kwargs: Dict[str, Any],
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
) -> CachingHandlerResponse:
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
args = args or ()
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
self.original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
cached_result: Optional[Any] = None
|
||||
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
||||
original_function=original_function
|
||||
):
|
||||
print_verbose("Checking Cache")
|
||||
cached_result = litellm.cache.get_cache(**new_kwargs)
|
||||
if cached_result is not None:
|
||||
if "detail" in cached_result:
|
||||
# implies an error occurred
|
||||
pass
|
||||
else:
|
||||
call_type = original_function.__name__
|
||||
cached_result = self._convert_cached_result_to_model_response(
|
||||
cached_result=cached_result,
|
||||
call_type=call_type,
|
||||
kwargs=kwargs,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
||||
args=args,
|
||||
)
|
||||
|
||||
# LOG SUCCESS
|
||||
cache_hit = True
|
||||
end_time = datetime.datetime.now()
|
||||
(
|
||||
model,
|
||||
custom_llm_provider,
|
||||
dynamic_api_key,
|
||||
api_base,
|
||||
) = litellm.get_llm_provider(
|
||||
model=model or "",
|
||||
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
||||
api_base=kwargs.get("api_base", None),
|
||||
api_key=kwargs.get("api_key", None),
|
||||
)
|
||||
self._update_litellm_logging_obj_environment(
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
kwargs=kwargs,
|
||||
cached_result=cached_result,
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(cached_result, start_time, end_time, cache_hit),
|
||||
).start()
|
||||
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
|
||||
**kwargs
|
||||
)
|
||||
if (
|
||||
isinstance(cached_result, BaseModel)
|
||||
or isinstance(cached_result, CustomStreamWrapper)
|
||||
) and hasattr(cached_result, "_hidden_params"):
|
||||
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
|
||||
return CachingHandlerResponse(cached_result=cached_result)
|
||||
return CachingHandlerResponse(cached_result=cached_result)
|
||||
|
||||
def _process_async_embedding_cached_response(
|
||||
self,
|
||||
final_embedding_cached_response: Optional[EmbeddingResponse],
|
||||
cached_result: List[Optional[Dict[str, Any]]],
|
||||
kwargs: Dict[str, Any],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
start_time: datetime.datetime,
|
||||
model: str,
|
||||
) -> Tuple[Optional[EmbeddingResponse], bool]:
|
||||
"""
|
||||
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
|
||||
|
||||
For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others
|
||||
This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
|
||||
|
||||
Args:
|
||||
final_embedding_cached_response: Optional[EmbeddingResponse]:
|
||||
cached_result: List[Optional[Dict[str, Any]]]:
|
||||
kwargs: Dict[str, Any]:
|
||||
logging_obj: LiteLLMLoggingObj:
|
||||
start_time: datetime.datetime:
|
||||
model: str:
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[EmbeddingResponse], bool]:
|
||||
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
|
||||
|
||||
|
||||
"""
|
||||
embedding_all_elements_cache_hit: bool = False
|
||||
remaining_list = []
|
||||
non_null_list = []
|
||||
for idx, cr in enumerate(cached_result):
|
||||
if cr is None:
|
||||
remaining_list.append(kwargs["input"][idx])
|
||||
else:
|
||||
non_null_list.append((idx, cr))
|
||||
original_kwargs_input = kwargs["input"]
|
||||
kwargs["input"] = remaining_list
|
||||
if len(non_null_list) > 0:
|
||||
print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
|
||||
final_embedding_cached_response = EmbeddingResponse(
|
||||
model=kwargs.get("model"),
|
||||
data=[None] * len(original_kwargs_input),
|
||||
)
|
||||
final_embedding_cached_response._hidden_params["cache_hit"] = True
|
||||
|
||||
for val in non_null_list:
|
||||
idx, cr = val # (idx, cr) tuple
|
||||
if cr is not None:
|
||||
final_embedding_cached_response.data[idx] = Embedding(
|
||||
embedding=cr["embedding"],
|
||||
index=idx,
|
||||
object="embedding",
|
||||
)
|
||||
if len(remaining_list) == 0:
|
||||
# LOG SUCCESS
|
||||
cache_hit = True
|
||||
embedding_all_elements_cache_hit = True
|
||||
end_time = datetime.datetime.now()
|
||||
(
|
||||
model,
|
||||
custom_llm_provider,
|
||||
dynamic_api_key,
|
||||
api_base,
|
||||
) = litellm.get_llm_provider(
|
||||
model=model,
|
||||
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
||||
api_base=kwargs.get("api_base", None),
|
||||
api_key=kwargs.get("api_key", None),
|
||||
)
|
||||
|
||||
self._update_litellm_logging_obj_environment(
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
kwargs=kwargs,
|
||||
cached_result=final_embedding_cached_response,
|
||||
is_async=True,
|
||||
is_embedding=True,
|
||||
)
|
||||
self._async_log_cache_hit_on_callbacks(
|
||||
logging_obj=logging_obj,
|
||||
cached_result=final_embedding_cached_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
return final_embedding_cached_response, embedding_all_elements_cache_hit
|
||||
return final_embedding_cached_response, embedding_all_elements_cache_hit
|
||||
|
||||
def _combine_cached_embedding_response_with_api_result(
|
||||
self,
|
||||
_caching_handler_response: CachingHandlerResponse,
|
||||
embedding_response: EmbeddingResponse,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
) -> EmbeddingResponse:
|
||||
"""
|
||||
Combines the cached embedding response with the API EmbeddingResponse
|
||||
|
||||
For caching there can be a cache hit for some of the inputs in the list and a cache miss for others
|
||||
This function combines the cached embedding response with the API EmbeddingResponse
|
||||
|
||||
Args:
|
||||
caching_handler_response: CachingHandlerResponse:
|
||||
embedding_response: EmbeddingResponse:
|
||||
|
||||
Returns:
|
||||
EmbeddingResponse:
|
||||
"""
|
||||
if _caching_handler_response.final_embedding_cached_response is None:
|
||||
return embedding_response
|
||||
|
||||
idx = 0
|
||||
final_data_list = []
|
||||
for item in _caching_handler_response.final_embedding_cached_response.data:
|
||||
if item is None and embedding_response.data is not None:
|
||||
final_data_list.append(embedding_response.data[idx])
|
||||
idx += 1
|
||||
else:
|
||||
final_data_list.append(item)
|
||||
|
||||
_caching_handler_response.final_embedding_cached_response.data = final_data_list
|
||||
_caching_handler_response.final_embedding_cached_response._hidden_params[
|
||||
"cache_hit"
|
||||
] = True
|
||||
_caching_handler_response.final_embedding_cached_response._response_ms = (
|
||||
end_time - start_time
|
||||
).total_seconds() * 1000
|
||||
return _caching_handler_response.final_embedding_cached_response
|
||||
|
||||
def _async_log_cache_hit_on_callbacks(
|
||||
self,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
cached_result: Any,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
cache_hit: bool,
|
||||
):
|
||||
"""
|
||||
Helper function to log the success of a cached result on callbacks
|
||||
|
||||
Args:
|
||||
logging_obj (LiteLLMLoggingObj): The logging object.
|
||||
cached_result: The cached result.
|
||||
start_time (datetime): The start time of the operation.
|
||||
end_time (datetime): The end time of the operation.
|
||||
cache_hit (bool): Whether it was a cache hit.
|
||||
"""
|
||||
asyncio.create_task(
|
||||
logging_obj.async_success_handler(
|
||||
cached_result, start_time, end_time, cache_hit
|
||||
)
|
||||
)
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(cached_result, start_time, end_time, cache_hit),
|
||||
).start()
|
||||
|
||||
async def _retrieve_from_cache(
|
||||
self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Internal method to
|
||||
- get cache key
|
||||
- check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
|
||||
- async get cache value
|
||||
- return the cached value
|
||||
|
||||
Args:
|
||||
call_type: str:
|
||||
kwargs: Dict[str, Any]:
|
||||
args: Optional[Tuple[Any, ...]] = None:
|
||||
|
||||
Returns:
|
||||
Optional[Any]:
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
if litellm.cache is None:
|
||||
return None
|
||||
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
self.original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
cached_result: Optional[Any] = None
|
||||
if call_type == CallTypes.aembedding.value and isinstance(
|
||||
new_kwargs["input"], list
|
||||
):
|
||||
tasks = []
|
||||
for idx, i in enumerate(new_kwargs["input"]):
|
||||
preset_cache_key = litellm.cache.get_cache_key(
|
||||
**{**new_kwargs, "input": i}
|
||||
)
|
||||
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
|
||||
cached_result = await asyncio.gather(*tasks)
|
||||
## check if cached result is None ##
|
||||
if cached_result is not None and isinstance(cached_result, list):
|
||||
# set cached_result to None if all elements are None
|
||||
if all(result is None for result in cached_result):
|
||||
cached_result = None
|
||||
else:
|
||||
if litellm.cache._supports_async() is True:
|
||||
cached_result = await litellm.cache.async_get_cache(**new_kwargs)
|
||||
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
|
||||
cached_result = litellm.cache.get_cache(**new_kwargs)
|
||||
return cached_result
|
||||
|
||||
def _convert_cached_result_to_model_response(
|
||||
self,
|
||||
cached_result: Any,
|
||||
call_type: str,
|
||||
kwargs: Dict[str, Any],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
model: str,
|
||||
args: Tuple[Any, ...],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> Optional[
|
||||
Union[
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
RerankResponse,
|
||||
TranscriptionResponse,
|
||||
CustomStreamWrapper,
|
||||
]
|
||||
]:
|
||||
"""
|
||||
Internal method to process the cached result
|
||||
|
||||
Checks the call type and converts the cached result to the appropriate model response object
|
||||
example if call type is text_completion -> returns TextCompletionResponse object
|
||||
|
||||
Args:
|
||||
cached_result: Any:
|
||||
call_type: str:
|
||||
kwargs: Dict[str, Any]:
|
||||
logging_obj: LiteLLMLoggingObj:
|
||||
model: str:
|
||||
custom_llm_provider: Optional[str] = None:
|
||||
args: Optional[Tuple[Any, ...]] = None:
|
||||
|
||||
Returns:
|
||||
Optional[Any]:
|
||||
"""
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
if (
|
||||
call_type == CallTypes.acompletion.value
|
||||
or call_type == CallTypes.completion.value
|
||||
) and isinstance(cached_result, dict):
|
||||
if kwargs.get("stream", False) is True:
|
||||
cached_result = self._convert_cached_stream_response(
|
||||
cached_result=cached_result,
|
||||
call_type=call_type,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=ModelResponse(),
|
||||
)
|
||||
if (
|
||||
call_type == CallTypes.atext_completion.value
|
||||
or call_type == CallTypes.text_completion.value
|
||||
) and isinstance(cached_result, dict):
|
||||
if kwargs.get("stream", False) is True:
|
||||
cached_result = self._convert_cached_stream_response(
|
||||
cached_result=cached_result,
|
||||
call_type=call_type,
|
||||
logging_obj=logging_obj,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
cached_result = TextCompletionResponse(**cached_result)
|
||||
elif (
|
||||
call_type == CallTypes.aembedding.value
|
||||
or call_type == CallTypes.embedding.value
|
||||
) and isinstance(cached_result, dict):
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=EmbeddingResponse(),
|
||||
response_type="embedding",
|
||||
)
|
||||
|
||||
elif (
|
||||
call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
|
||||
) and isinstance(cached_result, dict):
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=None,
|
||||
response_type="rerank",
|
||||
)
|
||||
elif (
|
||||
call_type == CallTypes.atranscription.value
|
||||
or call_type == CallTypes.transcription.value
|
||||
) and isinstance(cached_result, dict):
|
||||
hidden_params = {
|
||||
"model": "whisper-1",
|
||||
"custom_llm_provider": custom_llm_provider,
|
||||
"cache_hit": True,
|
||||
}
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=TranscriptionResponse(),
|
||||
response_type="audio_transcription",
|
||||
hidden_params=hidden_params,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(cached_result, "_hidden_params")
|
||||
and cached_result._hidden_params is not None
|
||||
and isinstance(cached_result._hidden_params, dict)
|
||||
):
|
||||
cached_result._hidden_params["cache_hit"] = True
|
||||
return cached_result
|
||||
|
||||
def _convert_cached_stream_response(
|
||||
self,
|
||||
cached_result: Any,
|
||||
call_type: str,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
model: str,
|
||||
) -> CustomStreamWrapper:
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
convert_to_streaming_response,
|
||||
convert_to_streaming_response_async,
|
||||
)
|
||||
|
||||
_stream_cached_result: Union[AsyncGenerator, Generator]
|
||||
if (
|
||||
call_type == CallTypes.acompletion.value
|
||||
or call_type == CallTypes.atext_completion.value
|
||||
):
|
||||
_stream_cached_result = convert_to_streaming_response_async(
|
||||
response_object=cached_result,
|
||||
)
|
||||
else:
|
||||
_stream_cached_result = convert_to_streaming_response(
|
||||
response_object=cached_result,
|
||||
)
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=_stream_cached_result,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async def async_set_cache(
|
||||
self,
|
||||
result: Any,
|
||||
original_function: Callable,
|
||||
kwargs: Dict[str, Any],
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
):
|
||||
"""
|
||||
Internal method to check the type of the result & cache used and adds the result to the cache accordingly
|
||||
|
||||
Args:
|
||||
result: Any:
|
||||
original_function: Callable:
|
||||
kwargs: Dict[str, Any]:
|
||||
args: Optional[Tuple[Any, ...]] = None:
|
||||
|
||||
Returns:
|
||||
None
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
if litellm.cache is None:
|
||||
return
|
||||
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
# [OPTIONAL] ADD TO CACHE
|
||||
if self._should_store_result_in_cache(
|
||||
original_function=original_function, kwargs=new_kwargs
|
||||
):
|
||||
if (
|
||||
isinstance(result, litellm.ModelResponse)
|
||||
or isinstance(result, litellm.EmbeddingResponse)
|
||||
or isinstance(result, TranscriptionResponse)
|
||||
or isinstance(result, RerankResponse)
|
||||
):
|
||||
if (
|
||||
isinstance(result, EmbeddingResponse)
|
||||
and isinstance(new_kwargs["input"], list)
|
||||
and litellm.cache is not None
|
||||
and not isinstance(
|
||||
litellm.cache.cache, S3Cache
|
||||
) # s3 doesn't support bulk writing. Exclude.
|
||||
):
|
||||
asyncio.create_task(
|
||||
litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
|
||||
)
|
||||
elif isinstance(litellm.cache.cache, S3Cache):
|
||||
threading.Thread(
|
||||
target=litellm.cache.add_cache,
|
||||
args=(result,),
|
||||
kwargs=new_kwargs,
|
||||
).start()
|
||||
else:
|
||||
asyncio.create_task(
|
||||
litellm.cache.async_add_cache(
|
||||
result.model_dump_json(), **new_kwargs
|
||||
)
|
||||
)
|
||||
else:
|
||||
asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
|
||||
|
||||
def sync_set_cache(
|
||||
self,
|
||||
result: Any,
|
||||
kwargs: Dict[str, Any],
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
):
|
||||
"""
|
||||
Sync internal method to add the result to the cache
|
||||
"""
|
||||
|
||||
new_kwargs = kwargs.copy()
|
||||
new_kwargs.update(
|
||||
convert_args_to_kwargs(
|
||||
self.original_function,
|
||||
args,
|
||||
)
|
||||
)
|
||||
if litellm.cache is None:
|
||||
return
|
||||
|
||||
if self._should_store_result_in_cache(
|
||||
original_function=self.original_function, kwargs=new_kwargs
|
||||
):
|
||||
litellm.cache.add_cache(result, **new_kwargs)
|
||||
|
||||
return
|
||||
|
||||
def _should_store_result_in_cache(
|
||||
self, original_function: Callable, kwargs: Dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Helper function to determine if the result should be stored in the cache.
|
||||
|
||||
Returns:
|
||||
bool: True if the result should be stored in the cache, False otherwise.
|
||||
"""
|
||||
return (
|
||||
(litellm.cache is not None)
|
||||
and litellm.cache.supported_call_types is not None
|
||||
and (str(original_function.__name__) in litellm.cache.supported_call_types)
|
||||
and (kwargs.get("cache", {}).get("no-store", False) is not True)
|
||||
)
|
||||
|
||||
def _is_call_type_supported_by_cache(
|
||||
self,
|
||||
original_function: Callable,
|
||||
) -> bool:
|
||||
"""
|
||||
Helper function to determine if the call type is supported by the cache.
|
||||
|
||||
call types are acompletion, aembedding, atext_completion, atranscription, arerank
|
||||
|
||||
Defined on `litellm.types.utils.CallTypes`
|
||||
|
||||
Returns:
|
||||
bool: True if the call type is supported by the cache, False otherwise.
|
||||
"""
|
||||
if (
|
||||
litellm.cache is not None
|
||||
and litellm.cache.supported_call_types is not None
|
||||
and str(original_function.__name__) in litellm.cache.supported_call_types
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
|
||||
"""
|
||||
Internal method to add the streaming response to the cache
|
||||
|
||||
|
||||
- If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object
|
||||
- Else append the chunk to self.async_streaming_chunks
|
||||
|
||||
"""
|
||||
|
||||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse]
|
||||
] = _assemble_complete_response_from_streaming_chunks(
|
||||
result=processed_chunk,
|
||||
start_time=self.start_time,
|
||||
end_time=datetime.datetime.now(),
|
||||
request_kwargs=self.request_kwargs,
|
||||
streaming_chunks=self.async_streaming_chunks,
|
||||
is_async=True,
|
||||
)
|
||||
# if a complete_streaming_response is assembled, add it to the cache
|
||||
if complete_streaming_response is not None:
|
||||
await self.async_set_cache(
|
||||
result=complete_streaming_response,
|
||||
original_function=self.original_function,
|
||||
kwargs=self.request_kwargs,
|
||||
)
|
||||
|
||||
def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
|
||||
"""
|
||||
Sync internal method to add the streaming response to the cache
|
||||
"""
|
||||
complete_streaming_response: Optional[
|
||||
Union[ModelResponse, TextCompletionResponse]
|
||||
] = _assemble_complete_response_from_streaming_chunks(
|
||||
result=processed_chunk,
|
||||
start_time=self.start_time,
|
||||
end_time=datetime.datetime.now(),
|
||||
request_kwargs=self.request_kwargs,
|
||||
streaming_chunks=self.sync_streaming_chunks,
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
# if a complete_streaming_response is assembled, add it to the cache
|
||||
if complete_streaming_response is not None:
|
||||
self.sync_set_cache(
|
||||
result=complete_streaming_response,
|
||||
kwargs=self.request_kwargs,
|
||||
)
|
||||
|
||||
def _update_litellm_logging_obj_environment(
|
||||
self,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
model: str,
|
||||
kwargs: Dict[str, Any],
|
||||
cached_result: Any,
|
||||
is_async: bool,
|
||||
is_embedding: bool = False,
|
||||
):
|
||||
"""
|
||||
Helper function to update the LiteLLMLoggingObj environment variables.
|
||||
|
||||
Args:
|
||||
logging_obj (LiteLLMLoggingObj): The logging object to update.
|
||||
model (str): The model being used.
|
||||
kwargs (Dict[str, Any]): The keyword arguments from the original function call.
|
||||
cached_result (Any): The cached result to log.
|
||||
is_async (bool): Whether the call is asynchronous or not.
|
||||
is_embedding (bool): Whether the call is for embeddings or not.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
litellm_params = {
|
||||
"logger_fn": kwargs.get("logger_fn", None),
|
||||
"acompletion": is_async,
|
||||
"api_base": kwargs.get("api_base", ""),
|
||||
"metadata": kwargs.get("metadata", {}),
|
||||
"model_info": kwargs.get("model_info", {}),
|
||||
"proxy_server_request": kwargs.get("proxy_server_request", None),
|
||||
"stream_response": kwargs.get("stream_response", {}),
|
||||
}
|
||||
|
||||
if litellm.cache is not None:
|
||||
litellm_params[
|
||||
"preset_cache_key"
|
||||
] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
|
||||
else:
|
||||
litellm_params["preset_cache_key"] = None
|
||||
|
||||
logging_obj.update_environment_variables(
|
||||
model=model,
|
||||
user=kwargs.get("user", None),
|
||||
optional_params={},
|
||||
litellm_params=litellm_params,
|
||||
input=(
|
||||
kwargs.get("messages", "")
|
||||
if not is_embedding
|
||||
else kwargs.get("input", "")
|
||||
),
|
||||
api_key=kwargs.get("api_key", None),
|
||||
original_response=str(cached_result),
|
||||
additional_args=None,
|
||||
stream=kwargs.get("stream", False),
|
||||
)
|
||||
|
||||
|
||||
def convert_args_to_kwargs(
|
||||
original_function: Callable,
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
# Get the signature of the original function
|
||||
signature = inspect.signature(original_function)
|
||||
|
||||
# Get parameter names in the order they appear in the original function
|
||||
param_names = list(signature.parameters.keys())
|
||||
|
||||
# Create a mapping of positional arguments to parameter names
|
||||
args_to_kwargs = {}
|
||||
if args:
|
||||
for index, arg in enumerate(args):
|
||||
if index < len(param_names):
|
||||
param_name = param_names[index]
|
||||
args_to_kwargs[param_name] = arg
|
||||
|
||||
return args_to_kwargs
|
||||
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class DiskCache(BaseCache):
|
||||
def __init__(self, disk_cache_dir: Optional[str] = None):
|
||||
import diskcache as dc
|
||||
|
||||
# if users don't provider one, use the default litellm cache
|
||||
if disk_cache_dir is None:
|
||||
self.disk_cache = dc.Cache(".litellm_cache")
|
||||
else:
|
||||
self.disk_cache = dc.Cache(disk_cache_dir)
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
if "ttl" in kwargs:
|
||||
self.disk_cache.set(key, value, expire=kwargs["ttl"])
|
||||
else:
|
||||
self.disk_cache.set(key, value)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if "ttl" in kwargs:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
original_cached_response = self.disk_cache.get(key)
|
||||
if original_cached_response:
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response) # type: ignore
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value # type: ignore
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value # type: ignore
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
def flush_cache(self):
|
||||
self.disk_cache.clear()
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def delete_cache(self, key):
|
||||
self.disk_cache.pop(key)
|
||||
434
.venv/lib/python3.10/site-packages/litellm/caching/dual_cache.py
Normal file
434
.venv/lib/python3.10/site-packages/litellm/caching/dual_cache.py
Normal file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
|
||||
|
||||
Has 4 primary methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LimitedSizeOrderedDict(OrderedDict):
|
||||
def __init__(self, *args, max_size=100, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_size = max_size
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# If inserting a new key exceeds max size, remove the oldest item
|
||||
if len(self) >= self.max_size:
|
||||
self.popitem(last=False)
|
||||
super().__setitem__(key, value)
|
||||
|
||||
|
||||
class DualCache(BaseCache):
|
||||
"""
|
||||
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
|
||||
When data is updated or inserted, it is written to both the in-memory cache + Redis.
|
||||
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_memory_cache: Optional[InMemoryCache] = None,
|
||||
redis_cache: Optional[RedisCache] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
default_redis_ttl: Optional[float] = None,
|
||||
default_redis_batch_cache_expiry: Optional[float] = None,
|
||||
default_max_redis_batch_cache_size: int = 100,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# If in_memory_cache is not provided, use the default InMemoryCache
|
||||
self.in_memory_cache = in_memory_cache or InMemoryCache()
|
||||
# If redis_cache is not provided, use the default RedisCache
|
||||
self.redis_cache = redis_cache
|
||||
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
|
||||
max_size=default_max_redis_batch_cache_size
|
||||
)
|
||||
self.redis_batch_cache_expiry = (
|
||||
default_redis_batch_cache_expiry
|
||||
or litellm.default_redis_batch_cache_expiry
|
||||
or 10
|
||||
)
|
||||
self.default_in_memory_ttl = (
|
||||
default_in_memory_ttl or litellm.default_in_memory_ttl
|
||||
)
|
||||
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
|
||||
|
||||
def update_cache_ttl(
|
||||
self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
|
||||
):
|
||||
if default_in_memory_ttl is not None:
|
||||
self.default_in_memory_ttl = default_in_memory_ttl
|
||||
|
||||
if default_redis_ttl is not None:
|
||||
self.default_redis_ttl = default_redis_ttl
|
||||
|
||||
def set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
# Update both Redis and in-memory cache
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
|
||||
self.in_memory_cache.set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
self.redis_cache.set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(e)
|
||||
|
||||
def increment_cache(
|
||||
self, key, value: int, local_only: bool = False, **kwargs
|
||||
) -> int:
|
||||
"""
|
||||
Key - the key in cache
|
||||
|
||||
Value - int - the value you want to increment by
|
||||
|
||||
Returns - int - the incremented value
|
||||
"""
|
||||
try:
|
||||
result: int = value
|
||||
if self.in_memory_cache is not None:
|
||||
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
result = self.redis_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
raise e
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
key,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
result = None
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
|
||||
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if result is None and self.redis_cache is not None and local_only is False:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = self.redis_cache.get_cache(
|
||||
key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
|
||||
|
||||
result = redis_result
|
||||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
def batch_get_cache(
|
||||
self,
|
||||
keys: list,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
received_args = locals()
|
||||
received_args.pop("self")
|
||||
|
||||
def run_in_new_loop():
|
||||
"""Run the coroutine in a new event loop within this thread."""
|
||||
new_loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(new_loop)
|
||||
return new_loop.run_until_complete(
|
||||
self.async_batch_get_cache(**received_args)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
try:
|
||||
# First, try to get the current event loop
|
||||
_ = asyncio.get_running_loop()
|
||||
# If we're already in an event loop, run in a separate thread
|
||||
# to avoid nested event loop issues
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result()
|
||||
|
||||
except RuntimeError:
|
||||
# No running event loop, we can safely run in this thread
|
||||
return run_in_new_loop()
|
||||
|
||||
async def async_get_cache(
|
||||
self,
|
||||
key,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
print_verbose(
|
||||
f"async get cache: cache key: {key}; local_only: {local_only}"
|
||||
)
|
||||
result = None
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = await self.in_memory_cache.async_get_cache(
|
||||
key, **kwargs
|
||||
)
|
||||
|
||||
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if result is None and self.redis_cache is not None and local_only is False:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_get_cache(
|
||||
key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
key, redis_result, **kwargs
|
||||
)
|
||||
|
||||
result = redis_result
|
||||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
def get_redis_batch_keys(
|
||||
self,
|
||||
current_time: float,
|
||||
keys: List[str],
|
||||
result: List[Any],
|
||||
) -> List[str]:
|
||||
sublist_keys = []
|
||||
for key, value in zip(keys, result):
|
||||
if value is None:
|
||||
if (
|
||||
key not in self.last_redis_batch_access_time
|
||||
or current_time - self.last_redis_batch_access_time[key]
|
||||
>= self.redis_batch_cache_expiry
|
||||
):
|
||||
sublist_keys.append(key)
|
||||
return sublist_keys
|
||||
|
||||
async def async_batch_get_cache(
|
||||
self,
|
||||
keys: list,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
result = [None for _ in range(len(keys))]
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
|
||||
keys, **kwargs
|
||||
)
|
||||
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if None in result and self.redis_cache is not None and local_only is False:
|
||||
"""
|
||||
- for the none values in the result
|
||||
- check the redis cache
|
||||
"""
|
||||
current_time = time.time()
|
||||
sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
|
||||
|
||||
# Only hit Redis if the last access time was more than 5 seconds ago
|
||||
if len(sublist_keys) > 0:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_batch_get_cache(
|
||||
sublist_keys, parent_otel_span=parent_otel_span
|
||||
)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
for key, value in redis_result.items():
|
||||
if value is not None:
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
key, redis_result[key], **kwargs
|
||||
)
|
||||
# Update the last access time for each key fetched from Redis
|
||||
self.last_redis_batch_access_time[key] = current_time
|
||||
|
||||
for key, value in redis_result.items():
|
||||
index = keys.index(key)
|
||||
result[index] = value
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
print_verbose(
|
||||
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
await self.redis_cache.async_set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
|
||||
)
|
||||
|
||||
# async_batch_set_cache
|
||||
async def async_set_cache_pipeline(
|
||||
self, cache_list: list, local_only: bool = False, **kwargs
|
||||
):
|
||||
"""
|
||||
Batch write values to the cache
|
||||
"""
|
||||
print_verbose(
|
||||
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
await self.in_memory_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
await self.redis_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
|
||||
)
|
||||
|
||||
async def async_increment_cache(
|
||||
self,
|
||||
key,
|
||||
value: float,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
) -> float:
|
||||
"""
|
||||
Key - the key in cache
|
||||
|
||||
Value - float - the value you want to increment by
|
||||
|
||||
Returns - float - the incremented value
|
||||
"""
|
||||
try:
|
||||
result: float = value
|
||||
if self.in_memory_cache is not None:
|
||||
result = await self.in_memory_cache.async_increment(
|
||||
key, value, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
result = await self.redis_cache.async_increment(
|
||||
key,
|
||||
value,
|
||||
parent_otel_span=parent_otel_span,
|
||||
ttl=kwargs.get("ttl", None),
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e # don't log if exception is raised
|
||||
|
||||
async def async_set_cache_sadd(
|
||||
self, key, value: List, local_only: bool = False, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Add value to a set
|
||||
|
||||
Key - the key in cache
|
||||
|
||||
Value - str - the value you want to add to the set
|
||||
|
||||
Returns - None
|
||||
"""
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
_ = await self.in_memory_cache.async_set_cache_sadd(
|
||||
key, value, ttl=kwargs.get("ttl", None)
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
_ = await self.redis_cache.async_set_cache_sadd(
|
||||
key, value, ttl=kwargs.get("ttl", None)
|
||||
)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise e # don't log, if exception is raised
|
||||
|
||||
def flush_cache(self):
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.flush_cache()
|
||||
if self.redis_cache is not None:
|
||||
self.redis_cache.flush_cache()
|
||||
|
||||
def delete_cache(self, key):
|
||||
"""
|
||||
Delete a key from the cache
|
||||
"""
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.delete_cache(key)
|
||||
if self.redis_cache is not None:
|
||||
self.redis_cache.delete_cache(key)
|
||||
|
||||
async def async_delete_cache(self, key: str):
|
||||
"""
|
||||
Delete a key from the cache
|
||||
"""
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.delete_cache(key)
|
||||
if self.redis_cache is not None:
|
||||
await self.redis_cache.async_delete_cache(key)
|
||||
|
||||
async def async_get_ttl(self, key: str) -> Optional[int]:
|
||||
"""
|
||||
Get the remaining TTL of a key in in-memory cache or redis
|
||||
"""
|
||||
ttl = await self.in_memory_cache.async_get_ttl(key)
|
||||
if ttl is None and self.redis_cache is not None:
|
||||
ttl = await self.redis_cache.async_get_ttl(key)
|
||||
return ttl
|
||||
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
In-Memory Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
max_size_in_memory: Optional[int] = 200,
|
||||
default_ttl: Optional[
|
||||
int
|
||||
] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
|
||||
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
|
||||
):
|
||||
"""
|
||||
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default
|
||||
"""
|
||||
self.max_size_in_memory = (
|
||||
max_size_in_memory or 200
|
||||
) # set an upper bound of 200 items in-memory
|
||||
self.default_ttl = default_ttl or 600
|
||||
self.max_size_per_item = (
|
||||
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
) # 1MB = 1024KB
|
||||
|
||||
# in-memory cache
|
||||
self.cache_dict: dict = {}
|
||||
self.ttl_dict: dict = {}
|
||||
|
||||
def check_value_size(self, value: Any):
|
||||
"""
|
||||
Check if value size exceeds max_size_per_item (1MB)
|
||||
Returns True if value size is acceptable, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Fast path for common primitive types that are typically small
|
||||
if (
|
||||
isinstance(value, (bool, int, float, str))
|
||||
and len(str(value))
|
||||
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
): # Conservative estimate
|
||||
return True
|
||||
|
||||
# Direct size check for bytes objects
|
||||
if isinstance(value, bytes):
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
# Handle special types without full conversion when possible
|
||||
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
|
||||
size = value.__sizeof__() / 1024
|
||||
return size <= self.max_size_per_item
|
||||
|
||||
# Fallback for complex types
|
||||
if isinstance(value, BaseModel) and hasattr(
|
||||
value, "model_dump"
|
||||
): # Pydantic v2
|
||||
value = value.model_dump()
|
||||
elif hasattr(value, "isoformat"): # datetime objects
|
||||
return True # datetime strings are always small
|
||||
|
||||
# Only convert to JSON if absolutely necessary
|
||||
if not isinstance(value, (str, bytes)):
|
||||
value = json.dumps(value, default=str)
|
||||
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def evict_cache(self):
|
||||
"""
|
||||
Eviction policy:
|
||||
- check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict
|
||||
|
||||
|
||||
This guarantees the following:
|
||||
- 1. When item ttl not set: At minimumm each item will remain in memory for 5 minutes
|
||||
- 2. When ttl is set: the item will remain in memory for at least that amount of time
|
||||
- 3. the size of in-memory cache is bounded
|
||||
|
||||
"""
|
||||
for key in list(self.ttl_dict.keys()):
|
||||
if time.time() > self.ttl_dict[key]:
|
||||
self.cache_dict.pop(key, None)
|
||||
self.ttl_dict.pop(key, None)
|
||||
|
||||
# de-reference the removed item
|
||||
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
|
||||
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
|
||||
# This can occur when an object is referenced by another object, but the reference is never removed.
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
if len(self.cache_dict) >= self.max_size_in_memory:
|
||||
# only evict when cache is full
|
||||
self.evict_cache()
|
||||
if not self.check_value_size(value):
|
||||
return
|
||||
|
||||
self.cache_dict[key] = value
|
||||
if "ttl" in kwargs and kwargs["ttl"] is not None:
|
||||
self.ttl_dict[key] = time.time() + kwargs["ttl"]
|
||||
else:
|
||||
self.ttl_dict[key] = time.time() + self.default_ttl
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if ttl is not None:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
|
||||
"""
|
||||
Add value to set
|
||||
"""
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or set()
|
||||
for val in value:
|
||||
init_value.add(val)
|
||||
self.set_cache(key, init_value, ttl=ttl)
|
||||
return value
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
if key in self.cache_dict:
|
||||
if key in self.ttl_dict:
|
||||
if time.time() > self.ttl_dict[key]:
|
||||
self.cache_dict.pop(key, None)
|
||||
return None
|
||||
original_cached_response = self.cache_dict[key]
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
|
||||
return value
|
||||
|
||||
def flush_cache(self):
|
||||
self.cache_dict.clear()
|
||||
self.ttl_dict.clear()
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def delete_cache(self, key):
|
||||
self.cache_dict.pop(key, None)
|
||||
self.ttl_dict.pop(key, None)
|
||||
|
||||
async def async_get_ttl(self, key: str) -> Optional[int]:
|
||||
"""
|
||||
Get the remaining TTL of a key in in-memory cache
|
||||
"""
|
||||
return self.ttl_dict.get(key, None)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from .in_memory_cache import InMemoryCache
|
||||
|
||||
|
||||
class LLMClientCache(InMemoryCache):
|
||||
def update_cache_key_with_event_loop(self, key):
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
If none, use the key as is.
|
||||
"""
|
||||
try:
|
||||
event_loop = asyncio.get_event_loop()
|
||||
stringified_event_loop = str(id(event_loop))
|
||||
return f"{key}-{stringified_event_loop}"
|
||||
except Exception: # handle no current event loop
|
||||
return key
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
return super().set_cache(key, value, **kwargs)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
return await super().async_set_cache(key, value, **kwargs)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
|
||||
return super().get_cache(key, **kwargs)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
|
||||
return await super().async_get_cache(key, **kwargs)
|
||||
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
Qdrant Semantic Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class QdrantSemanticCache(BaseCache):
|
||||
def __init__( # noqa: PLR0915
|
||||
self,
|
||||
qdrant_api_base=None,
|
||||
qdrant_api_key=None,
|
||||
collection_name=None,
|
||||
similarity_threshold=None,
|
||||
quantization_config=None,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
host_type=None,
|
||||
):
|
||||
import os
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
if collection_name is None:
|
||||
raise Exception("collection_name must be provided, passed None")
|
||||
|
||||
self.collection_name = collection_name
|
||||
print_verbose(
|
||||
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
|
||||
)
|
||||
|
||||
if similarity_threshold is None:
|
||||
raise Exception("similarity_threshold must be provided, passed None")
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.embedding_model = embedding_model
|
||||
headers = {}
|
||||
|
||||
# check if defined as os.environ/ variable
|
||||
if qdrant_api_base:
|
||||
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_base = get_secret_str(qdrant_api_base)
|
||||
if qdrant_api_key:
|
||||
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_key = get_secret_str(qdrant_api_key)
|
||||
|
||||
qdrant_api_base = (
|
||||
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
|
||||
)
|
||||
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if qdrant_api_key:
|
||||
headers["api-key"] = qdrant_api_key
|
||||
|
||||
if qdrant_api_base is None:
|
||||
raise ValueError("Qdrant url must be provided")
|
||||
|
||||
self.qdrant_api_base = qdrant_api_base
|
||||
self.qdrant_api_key = qdrant_api_key
|
||||
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
|
||||
|
||||
self.headers = headers
|
||||
|
||||
self.sync_client = _get_httpx_client()
|
||||
self.async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.Caching
|
||||
)
|
||||
|
||||
if quantization_config is None:
|
||||
print_verbose(
|
||||
"Quantization config is not provided. Default binary quantization will be used."
|
||||
)
|
||||
collection_exists = self.sync_client.get(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
|
||||
headers=self.headers,
|
||||
)
|
||||
if collection_exists.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Error from qdrant checking if /collections exist {collection_exists.text}"
|
||||
)
|
||||
|
||||
if collection_exists.json()["result"]["exists"]:
|
||||
collection_details = self.sync_client.get(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
headers=self.headers,
|
||||
)
|
||||
self.collection_info = collection_details.json()
|
||||
print_verbose(
|
||||
f"Collection already exists.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
else:
|
||||
if quantization_config is None or quantization_config == "binary":
|
||||
quantization_params = {
|
||||
"binary": {
|
||||
"always_ram": False,
|
||||
}
|
||||
}
|
||||
elif quantization_config == "scalar":
|
||||
quantization_params = {
|
||||
"scalar": {
|
||||
"type": "int8",
|
||||
"quantile": QDRANT_SCALAR_QUANTILE,
|
||||
"always_ram": False,
|
||||
}
|
||||
}
|
||||
elif quantization_config == "product":
|
||||
quantization_params = {
|
||||
"product": {"compression": "x16", "always_ram": False}
|
||||
}
|
||||
else:
|
||||
raise Exception(
|
||||
"Quantization config must be one of 'scalar', 'binary' or 'product'"
|
||||
)
|
||||
|
||||
new_collection_status = self.sync_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
json={
|
||||
"vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
|
||||
"quantization_config": quantization_params,
|
||||
},
|
||||
headers=self.headers,
|
||||
)
|
||||
if new_collection_status.json()["result"]:
|
||||
collection_details = self.sync_client.get(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
headers=self.headers,
|
||||
)
|
||||
self.collection_info = collection_details.json()
|
||||
print_verbose(
|
||||
f"New collection created.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
else:
|
||||
raise Exception("Error while creating new collection")
|
||||
|
||||
def _get_cache_logic(self, cached_response: Any):
|
||||
if cached_response is None:
|
||||
return cached_response
|
||||
try:
|
||||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
return cached_response
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||
import uuid
|
||||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
|
||||
# create an embedding for prompt
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
value = str(value)
|
||||
assert isinstance(value, str)
|
||||
|
||||
data = {
|
||||
"points": [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"vector": embedding,
|
||||
"payload": {
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
self.sync_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
return
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
|
||||
# convert to embedding
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
data = {
|
||||
"vector": embedding,
|
||||
"params": {
|
||||
"quantization": {
|
||||
"ignore": False,
|
||||
"rescore": True,
|
||||
"oversampling": 3.0,
|
||||
}
|
||||
},
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
search_response = self.sync_client.post(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
results = search_response.json()["result"]
|
||||
|
||||
if results is None:
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
if len(results) == 0:
|
||||
return None
|
||||
|
||||
similarity = results[0]["score"]
|
||||
cached_prompt = results[0]["payload"]["text"]
|
||||
|
||||
# check similarity, if more than self.similarity_threshold, return results
|
||||
print_verbose(
|
||||
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
||||
)
|
||||
if similarity >= self.similarity_threshold:
|
||||
# cache hit !
|
||||
cached_value = results[0]["payload"]["response"]
|
||||
print_verbose(
|
||||
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||
)
|
||||
return self._get_cache_logic(cached_response=cached_value)
|
||||
else:
|
||||
# cache miss !
|
||||
return None
|
||||
pass
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
import uuid
|
||||
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
# create an embedding for prompt
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# convert to embedding
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
value = str(value)
|
||||
assert isinstance(value, str)
|
||||
|
||||
data = {
|
||||
"points": [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"vector": embedding,
|
||||
"payload": {
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
await self.async_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
return
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# convert to embedding
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
data = {
|
||||
"vector": embedding,
|
||||
"params": {
|
||||
"quantization": {
|
||||
"ignore": False,
|
||||
"rescore": True,
|
||||
"oversampling": 3.0,
|
||||
}
|
||||
},
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
search_response = await self.async_client.post(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
|
||||
results = search_response.json()["result"]
|
||||
|
||||
if results is None:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
if len(results) == 0:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
similarity = results[0]["score"]
|
||||
cached_prompt = results[0]["payload"]["text"]
|
||||
|
||||
# check similarity, if more than self.similarity_threshold, return results
|
||||
print_verbose(
|
||||
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
||||
|
||||
if similarity >= self.similarity_threshold:
|
||||
# cache hit !
|
||||
cached_value = results[0]["payload"]["response"]
|
||||
print_verbose(
|
||||
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||
)
|
||||
return self._get_cache_logic(cached_response=cached_value)
|
||||
else:
|
||||
# cache miss !
|
||||
return None
|
||||
pass
|
||||
|
||||
async def _collection_info(self):
|
||||
return self.collection_info
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
1159
.venv/lib/python3.10/site-packages/litellm/caching/redis_cache.py
Normal file
1159
.venv/lib/python3.10/site-packages/litellm/caching/redis_cache.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Redis Cluster Cache implementation
|
||||
|
||||
Key differences:
|
||||
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.asyncio.client import Pipeline
|
||||
|
||||
pipeline = Pipeline
|
||||
async_redis_client = Redis
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
pipeline = Any
|
||||
async_redis_client = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
class RedisClusterCache(RedisCache):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
|
||||
self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
|
||||
|
||||
def init_async_client(self):
|
||||
from redis.asyncio import RedisCluster
|
||||
|
||||
from .._redis import get_redis_async_client
|
||||
|
||||
if self.redis_async_redis_cluster_client:
|
||||
return self.redis_async_redis_cluster_client
|
||||
|
||||
_redis_client = get_redis_async_client(
|
||||
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
|
||||
)
|
||||
if isinstance(_redis_client, RedisCluster):
|
||||
self.redis_async_redis_cluster_client = _redis_client
|
||||
|
||||
return _redis_client
|
||||
|
||||
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
||||
"""
|
||||
Overrides `_run_redis_mget_operation` in redis_cache.py
|
||||
"""
|
||||
return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
|
||||
|
||||
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
||||
"""
|
||||
Overrides `_async_run_redis_mget_operation` in redis_cache.py
|
||||
"""
|
||||
async_redis_cluster_client = self.init_async_client()
|
||||
return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore
|
||||
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Redis Semantic Cache implementation for LiteLLM
|
||||
|
||||
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
|
||||
This cache stores responses based on the semantic similarity of prompts rather than
|
||||
exact matching, allowing for more flexible caching of LLM responses.
|
||||
|
||||
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
|
||||
and their cached responses.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_str_from_messages,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class RedisSemanticCache(BaseCache):
|
||||
"""
|
||||
Redis-backed semantic cache for LLM responses.
|
||||
|
||||
This cache uses vector similarity to find semantically similar prompts that have been
|
||||
previously sent to the LLM, allowing for cache hits even when prompts are not identical
|
||||
but carry similar meaning.
|
||||
"""
|
||||
|
||||
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
redis_url: Optional[str] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
embedding_model: str = "text-embedding-ada-002",
|
||||
index_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Redis Semantic Cache.
|
||||
|
||||
Args:
|
||||
host: Redis host address
|
||||
port: Redis port
|
||||
password: Redis password
|
||||
redis_url: Full Redis URL (alternative to separate host/port/password)
|
||||
similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
|
||||
where 1.0 requires exact matches and 0.0 accepts any match
|
||||
embedding_model: Model to use for generating embeddings
|
||||
index_name: Name for the Redis index
|
||||
ttl: Default time-to-live for cache entries in seconds
|
||||
**kwargs: Additional arguments passed to the Redis client
|
||||
|
||||
Raises:
|
||||
Exception: If similarity_threshold is not provided or required Redis
|
||||
connection information is missing
|
||||
"""
|
||||
from redisvl.extensions.llmcache import SemanticCache
|
||||
from redisvl.utils.vectorize import CustomTextVectorizer
|
||||
|
||||
if index_name is None:
|
||||
index_name = self.DEFAULT_REDIS_INDEX_NAME
|
||||
|
||||
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
|
||||
|
||||
# Validate similarity threshold
|
||||
if similarity_threshold is None:
|
||||
raise ValueError("similarity_threshold must be provided, passed None")
|
||||
|
||||
# Store configuration
|
||||
self.similarity_threshold = similarity_threshold
|
||||
|
||||
# Convert similarity threshold [0,1] to distance threshold [0,2]
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
self.distance_threshold = 1 - similarity_threshold
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
# Set up Redis connection
|
||||
if redis_url is None:
|
||||
try:
|
||||
# Attempt to use provided parameters or fallback to environment variables
|
||||
host = host or os.environ["REDIS_HOST"]
|
||||
port = port or os.environ["REDIS_PORT"]
|
||||
password = password or os.environ["REDIS_PASSWORD"]
|
||||
except KeyError as e:
|
||||
# Raise a more informative exception if any of the required keys are missing
|
||||
missing_var = e.args[0]
|
||||
raise ValueError(
|
||||
f"Missing required Redis configuration: {missing_var}. "
|
||||
f"Provide {missing_var} or redis_url."
|
||||
) from e
|
||||
|
||||
redis_url = f"redis://:{password}@{host}:{port}"
|
||||
|
||||
print_verbose(f"Redis semantic-cache redis_url: {redis_url}")
|
||||
|
||||
# Initialize the Redis vectorizer and cache
|
||||
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
|
||||
|
||||
self.llmcache = SemanticCache(
|
||||
name=index_name,
|
||||
redis_url=redis_url,
|
||||
vectorizer=cache_vectorizer,
|
||||
distance_threshold=self.distance_threshold,
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
def _get_ttl(self, **kwargs) -> Optional[int]:
|
||||
"""
|
||||
Get the TTL (time-to-live) value for cache entries.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments that may contain a custom TTL
|
||||
|
||||
Returns:
|
||||
Optional[int]: The TTL value in seconds, or None if no TTL should be applied
|
||||
"""
|
||||
ttl = kwargs.get("ttl")
|
||||
if ttl is not None:
|
||||
ttl = int(ttl)
|
||||
return ttl
|
||||
|
||||
def _get_embedding(self, prompt: str) -> List[float]:
|
||||
"""
|
||||
Generate an embedding vector for the given prompt using the configured embedding model.
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
# Create an embedding from prompt
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
return embedding
|
||||
|
||||
def _get_cache_logic(self, cached_response: Any) -> Any:
|
||||
"""
|
||||
Process the cached response to prepare it for use.
|
||||
|
||||
Args:
|
||||
cached_response: The raw cached response
|
||||
|
||||
Returns:
|
||||
The processed cache response, or None if input was None
|
||||
"""
|
||||
if cached_response is None:
|
||||
return cached_response
|
||||
|
||||
# Convert bytes to string if needed
|
||||
if isinstance(cached_response, bytes):
|
||||
cached_response = cached_response.decode("utf-8")
|
||||
|
||||
# Convert string representation to Python object
|
||||
try:
|
||||
cached_response = json.loads(cached_response)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
except (ValueError, SyntaxError) as e:
|
||||
print_verbose(f"Error parsing cached response: {str(e)}")
|
||||
return None
|
||||
|
||||
return cached_response
|
||||
|
||||
def set_cache(self, key: str, value: Any, **kwargs) -> None:
|
||||
"""
|
||||
Store a value in the semantic cache.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
value: The response value to cache
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
and optional 'ttl' for time-to-live
|
||||
"""
|
||||
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
value_str: Optional[str] = None
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic caching")
|
||||
return
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
self.llmcache.store(prompt, value_str, ttl=int(ttl))
|
||||
else:
|
||||
self.llmcache.store(prompt, value_str)
|
||||
except Exception as e:
|
||||
print_verbose(
|
||||
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
|
||||
)
|
||||
|
||||
def get_cache(self, key: str, **kwargs) -> Any:
|
||||
"""
|
||||
Retrieve a semantically similar cached response.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
Returns:
|
||||
The cached response if a semantically similar prompt is found, else None
|
||||
"""
|
||||
print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
|
||||
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic cache lookup")
|
||||
return None
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
# Check the cache for semantically similar prompts
|
||||
results = self.llmcache.check(prompt=prompt)
|
||||
|
||||
# Return None if no similar prompts found
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# Process the best matching result
|
||||
cache_hit = results[0]
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
# Convert vector distance back to similarity score
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
similarity = 1 - vector_distance
|
||||
|
||||
cached_prompt = cache_hit["prompt"]
|
||||
cached_response = cache_hit["response"]
|
||||
|
||||
print_verbose(
|
||||
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
||||
f"actual similarity: {similarity}, "
|
||||
f"current prompt: {prompt}, "
|
||||
f"cached prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
|
||||
|
||||
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
|
||||
"""
|
||||
Asynchronously generate an embedding for the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
**kwargs: Additional arguments that may contain metadata
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
# Route the embedding request through the proxy if appropriate
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
|
||||
try:
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
# Use the router for embedding generation
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Generate embedding directly
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
# Extract and return the embedding vector
|
||||
return embedding_response["data"][0]["embedding"]
|
||||
except Exception as e:
|
||||
print_verbose(f"Error generating async embedding: {str(e)}")
|
||||
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
|
||||
|
||||
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
|
||||
"""
|
||||
Asynchronously store a value in the semantic cache.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
value: The response value to cache
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
and optional 'ttl' for time-to-live
|
||||
"""
|
||||
print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic caching")
|
||||
return
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
# Generate embedding for the value (response) to cache
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
ttl=ttl,
|
||||
)
|
||||
else:
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_set_cache: {str(e)}")
|
||||
|
||||
async def async_get_cache(self, key: str, **kwargs) -> Any:
|
||||
"""
|
||||
Asynchronously retrieve a semantically similar cached response.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
Returns:
|
||||
The cached response if a semantically similar prompt is found, else None
|
||||
"""
|
||||
print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")
|
||||
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic cache lookup")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
|
||||
# Generate embedding for the prompt
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# Check the cache for semantically similar prompts
|
||||
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
|
||||
|
||||
# handle results / cache hit
|
||||
if not results:
|
||||
kwargs.setdefault("metadata", {})[
|
||||
"semantic-similarity"
|
||||
] = 0.0 # TODO why here but not above??
|
||||
return None
|
||||
|
||||
cache_hit = results[0]
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
# Convert vector distance back to similarity
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
similarity = 1 - vector_distance
|
||||
|
||||
cached_prompt = cache_hit["prompt"]
|
||||
cached_response = cache_hit["response"]
|
||||
|
||||
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
||||
|
||||
print_verbose(
|
||||
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
||||
f"actual similarity: {similarity}, "
|
||||
f"current prompt: {prompt}, "
|
||||
f"cached prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_get_cache: {str(e)}")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
|
||||
async def _index_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the Redis index.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Information about the Redis index
|
||||
"""
|
||||
aindex = await self.llmcache._get_async_index()
|
||||
return await aindex.info()
|
||||
|
||||
async def async_set_cache_pipeline(
|
||||
self, cache_list: List[Tuple[str, Any]], **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Asynchronously store multiple values in the semantic cache.
|
||||
|
||||
Args:
|
||||
cache_list: List of (key, value) tuples to cache
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
try:
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")
|
||||
159
.venv/lib/python3.10/site-packages/litellm/caching/s3_cache.py
Normal file
159
.venv/lib/python3.10/site-packages/litellm/caching/s3_cache.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
S3 Cache implementation
|
||||
WARNING: DO NOT USE THIS IN PRODUCTION - This is not ASYNC
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class S3Cache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
s3_bucket_name,
|
||||
s3_region_name=None,
|
||||
s3_api_version=None,
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify=None,
|
||||
s3_endpoint_url=None,
|
||||
s3_aws_access_key_id=None,
|
||||
s3_aws_secret_access_key=None,
|
||||
s3_aws_session_token=None,
|
||||
s3_config=None,
|
||||
s3_path=None,
|
||||
**kwargs,
|
||||
):
|
||||
import boto3
|
||||
|
||||
self.bucket_name = s3_bucket_name
|
||||
self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
|
||||
# Create an S3 client with custom endpoint URL
|
||||
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=s3_region_name,
|
||||
endpoint_url=s3_endpoint_url,
|
||||
api_version=s3_api_version,
|
||||
use_ssl=s3_use_ssl,
|
||||
verify=s3_verify,
|
||||
aws_access_key_id=s3_aws_access_key_id,
|
||||
aws_secret_access_key=s3_aws_secret_access_key,
|
||||
aws_session_token=s3_aws_session_token,
|
||||
config=s3_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
try:
|
||||
print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
|
||||
ttl = kwargs.get("ttl", None)
|
||||
# Convert value to JSON before storing in S3
|
||||
serialized_value = json.dumps(value)
|
||||
key = self.key_prefix + key
|
||||
|
||||
if ttl is not None:
|
||||
cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
|
||||
import datetime
|
||||
|
||||
# Calculate expiration time
|
||||
expiration_time = datetime.datetime.now() + ttl
|
||||
|
||||
# Upload the data to S3 with the calculated expiration time
|
||||
self.s3_client.put_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
Body=serialized_value,
|
||||
Expires=expiration_time,
|
||||
CacheControl=cache_control,
|
||||
ContentType="application/json",
|
||||
ContentLanguage="en",
|
||||
ContentDisposition=f'inline; filename="{key}.json"',
|
||||
)
|
||||
else:
|
||||
cache_control = "immutable, max-age=31536000, s-maxage=31536000"
|
||||
# Upload the data to S3 without specifying Expires
|
||||
self.s3_client.put_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
Body=serialized_value,
|
||||
CacheControl=cache_control,
|
||||
ContentType="application/json",
|
||||
ContentLanguage="en",
|
||||
ContentDisposition=f'inline; filename="{key}.json"',
|
||||
)
|
||||
except Exception as e:
|
||||
# NON blocking - notify users S3 is throwing an exception
|
||||
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
import botocore
|
||||
|
||||
try:
|
||||
key = self.key_prefix + key
|
||||
|
||||
print_verbose(f"Get S3 Cache: key: {key}")
|
||||
# Download the data from S3
|
||||
cached_response = self.s3_client.get_object(
|
||||
Bucket=self.bucket_name, Key=key
|
||||
)
|
||||
|
||||
if cached_response is not None:
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = (
|
||||
cached_response["Body"].read().decode("utf-8")
|
||||
) # Convert bytes to string
|
||||
try:
|
||||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
if not isinstance(cached_response, dict):
|
||||
cached_response = dict(cached_response)
|
||||
verbose_logger.debug(
|
||||
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||
)
|
||||
|
||||
return cached_response
|
||||
except botocore.exceptions.ClientError as e: # type: ignore
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
verbose_logger.debug(
|
||||
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# NON blocking - notify users S3 is throwing an exception
|
||||
verbose_logger.error(
|
||||
f"S3 Caching: get_cache() - Got exception from S3: {e}"
|
||||
)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
def flush_cache(self):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
539
.venv/lib/python3.10/site-packages/litellm/constants.py
Normal file
539
.venv/lib/python3.10/site-packages/litellm/constants.py
Normal file
@@ -0,0 +1,539 @@
|
||||
from typing import List, Literal
|
||||
|
||||
ROUTER_MAX_FALLBACKS = 5
|
||||
DEFAULT_BATCH_SIZE = 512
|
||||
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
||||
DEFAULT_MAX_RETRIES = 2
|
||||
DEFAULT_MAX_RECURSE_DEPTH = 10
|
||||
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
||||
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
||||
)
|
||||
DEFAULT_MAX_TOKENS = 4096
|
||||
DEFAULT_ALLOWED_FAILS = 3
|
||||
DEFAULT_REDIS_SYNC_INTERVAL = 1
|
||||
DEFAULT_COOLDOWN_TIME_SECONDS = 5
|
||||
DEFAULT_REPLICATE_POLLING_RETRIES = 5
|
||||
DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
|
||||
DEFAULT_IMAGE_TOKEN_COUNT = 250
|
||||
DEFAULT_IMAGE_WIDTH = 300
|
||||
DEFAULT_IMAGE_HEIGHT = 300
|
||||
DEFAULT_MAX_TOKENS = 256 # used when providers need a default
|
||||
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
|
||||
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
|
||||
|
||||
DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET = 1024
|
||||
DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET = 2048
|
||||
DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET = 4096
|
||||
|
||||
########## Networking constants ##############################################################
|
||||
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
|
||||
|
||||
########### v2 Architecture constants for managing writing updates to the database ###########
|
||||
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
|
||||
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
|
||||
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer"
|
||||
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer"
|
||||
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
|
||||
MAX_SIZE_IN_MEMORY_QUEUE = 10000
|
||||
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000
|
||||
###############################################################################################
|
||||
MINIMUM_PROMPT_CACHE_TOKEN_COUNT = (
|
||||
1024 # minimum number of tokens to cache a prompt by Anthropic
|
||||
)
|
||||
DEFAULT_TRIM_RATIO = 0.75 # default ratio of tokens to trim from the end of a prompt
|
||||
HOURS_IN_A_DAY = 24
|
||||
DAYS_IN_A_WEEK = 7
|
||||
DAYS_IN_A_MONTH = 28
|
||||
DAYS_IN_A_YEAR = 365
|
||||
REPLICATE_MODEL_NAME_WITH_ID_LENGTH = 64
|
||||
#### TOKEN COUNTING ####
|
||||
FUNCTION_DEFINITION_TOKEN_COUNT = 9
|
||||
SYSTEM_MESSAGE_TOKEN_COUNT = 4
|
||||
TOOL_CHOICE_OBJECT_TOKEN_COUNT = 4
|
||||
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT = 10
|
||||
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT = 20
|
||||
MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES = 768
|
||||
MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES = 2000
|
||||
MAX_TILE_WIDTH = 512
|
||||
MAX_TILE_HEIGHT = 512
|
||||
OPENAI_FILE_SEARCH_COST_PER_1K_CALLS = 2.5 / 1000
|
||||
MIN_NON_ZERO_TEMPERATURE = 0.0001
|
||||
#### RELIABILITY ####
|
||||
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
|
||||
DEFAULT_MAX_LRU_CACHE_SIZE = 16
|
||||
INITIAL_RETRY_DELAY = 0.5
|
||||
MAX_RETRY_DELAY = 8.0
|
||||
JITTER = 0.75
|
||||
DEFAULT_IN_MEMORY_TTL = 5 # default time to live for the in-memory cache
|
||||
DEFAULT_POLLING_INTERVAL = 0.03 # default polling interval for the scheduler
|
||||
AZURE_OPERATION_POLLING_TIMEOUT = 120
|
||||
REDIS_SOCKET_TIMEOUT = 0.1
|
||||
REDIS_CONNECTION_POOL_TIMEOUT = 5
|
||||
NON_LLM_CONNECTION_TIMEOUT = 15 # timeout for adjacent services (e.g. jwt auth)
|
||||
MAX_EXCEPTION_MESSAGE_LENGTH = 2000
|
||||
BEDROCK_MAX_POLICY_SIZE = 75
|
||||
REPLICATE_POLLING_DELAY_SECONDS = 0.5
|
||||
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS = 4096
|
||||
TOGETHER_AI_4_B = 4
|
||||
TOGETHER_AI_8_B = 8
|
||||
TOGETHER_AI_21_B = 21
|
||||
TOGETHER_AI_41_B = 41
|
||||
TOGETHER_AI_80_B = 80
|
||||
TOGETHER_AI_110_B = 110
|
||||
TOGETHER_AI_EMBEDDING_150_M = 150
|
||||
TOGETHER_AI_EMBEDDING_350_M = 350
|
||||
QDRANT_SCALAR_QUANTILE = 0.99
|
||||
QDRANT_VECTOR_SIZE = 1536
|
||||
CACHED_STREAMING_CHUNK_DELAY = 0.02
|
||||
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 512
|
||||
DEFAULT_MAX_TOKENS_FOR_TRITON = 2000
|
||||
#### Networking settings ####
|
||||
request_timeout: float = 6000 # time in seconds
|
||||
STREAM_SSE_DONE_STRING: str = "[DONE]"
|
||||
### SPEND TRACKING ###
|
||||
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND = 0.001400 # price per second for a100 80GB
|
||||
FIREWORKS_AI_56_B_MOE = 56
|
||||
FIREWORKS_AI_176_B_MOE = 176
|
||||
FIREWORKS_AI_16_B = 16
|
||||
FIREWORKS_AI_80_B = 80
|
||||
|
||||
LITELLM_CHAT_PROVIDERS = [
|
||||
"openai",
|
||||
"openai_like",
|
||||
"xai",
|
||||
"custom_openai",
|
||||
"text-completion-openai",
|
||||
"cohere",
|
||||
"cohere_chat",
|
||||
"clarifai",
|
||||
"anthropic",
|
||||
"anthropic_text",
|
||||
"replicate",
|
||||
"huggingface",
|
||||
"together_ai",
|
||||
"openrouter",
|
||||
"vertex_ai",
|
||||
"vertex_ai_beta",
|
||||
"gemini",
|
||||
"ai21",
|
||||
"baseten",
|
||||
"azure",
|
||||
"azure_text",
|
||||
"azure_ai",
|
||||
"sagemaker",
|
||||
"sagemaker_chat",
|
||||
"bedrock",
|
||||
"vllm",
|
||||
"nlp_cloud",
|
||||
"petals",
|
||||
"oobabooga",
|
||||
"ollama",
|
||||
"ollama_chat",
|
||||
"deepinfra",
|
||||
"perplexity",
|
||||
"mistral",
|
||||
"groq",
|
||||
"nvidia_nim",
|
||||
"cerebras",
|
||||
"ai21_chat",
|
||||
"volcengine",
|
||||
"codestral",
|
||||
"text-completion-codestral",
|
||||
"deepseek",
|
||||
"sambanova",
|
||||
"maritalk",
|
||||
"cloudflare",
|
||||
"fireworks_ai",
|
||||
"friendliai",
|
||||
"watsonx",
|
||||
"watsonx_text",
|
||||
"triton",
|
||||
"predibase",
|
||||
"databricks",
|
||||
"empower",
|
||||
"github",
|
||||
"custom",
|
||||
"litellm_proxy",
|
||||
"hosted_vllm",
|
||||
"lm_studio",
|
||||
"galadriel",
|
||||
]
|
||||
|
||||
|
||||
OPENAI_CHAT_COMPLETION_PARAMS = [
|
||||
"functions",
|
||||
"function_call",
|
||||
"temperature",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_completion_tokens",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"audio",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"request_timeout",
|
||||
"api_base",
|
||||
"api_version",
|
||||
"api_key",
|
||||
"deployment_id",
|
||||
"organization",
|
||||
"base_url",
|
||||
"default_headers",
|
||||
"timeout",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"max_retries",
|
||||
"parallel_tool_calls",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"reasoning_effort",
|
||||
"extra_headers",
|
||||
"thinking",
|
||||
]
|
||||
|
||||
openai_compatible_endpoints: List = [
|
||||
"api.perplexity.ai",
|
||||
"api.endpoints.anyscale.com/v1",
|
||||
"api.deepinfra.com/v1/openai",
|
||||
"api.mistral.ai/v1",
|
||||
"codestral.mistral.ai/v1/chat/completions",
|
||||
"codestral.mistral.ai/v1/fim/completions",
|
||||
"api.groq.com/openai/v1",
|
||||
"https://integrate.api.nvidia.com/v1",
|
||||
"api.deepseek.com/v1",
|
||||
"api.together.xyz/v1",
|
||||
"app.empower.dev/api/v1",
|
||||
"https://api.friendli.ai/serverless/v1",
|
||||
"api.sambanova.ai/v1",
|
||||
"api.x.ai/v1",
|
||||
"api.galadriel.ai/v1",
|
||||
]
|
||||
|
||||
|
||||
openai_compatible_providers: List = [
|
||||
"anyscale",
|
||||
"mistral",
|
||||
"groq",
|
||||
"nvidia_nim",
|
||||
"cerebras",
|
||||
"sambanova",
|
||||
"ai21_chat",
|
||||
"ai21",
|
||||
"volcengine",
|
||||
"codestral",
|
||||
"deepseek",
|
||||
"deepinfra",
|
||||
"perplexity",
|
||||
"xinference",
|
||||
"xai",
|
||||
"together_ai",
|
||||
"fireworks_ai",
|
||||
"empower",
|
||||
"friendliai",
|
||||
"azure_ai",
|
||||
"github",
|
||||
"litellm_proxy",
|
||||
"hosted_vllm",
|
||||
"lm_studio",
|
||||
"galadriel",
|
||||
]
|
||||
openai_text_completion_compatible_providers: List = (
|
||||
[ # providers that support `/v1/completions`
|
||||
"together_ai",
|
||||
"fireworks_ai",
|
||||
"hosted_vllm",
|
||||
]
|
||||
)
|
||||
_openai_like_providers: List = [
|
||||
"predibase",
|
||||
"databricks",
|
||||
"watsonx",
|
||||
] # private helper. similar to openai but require some custom auth / endpoint handling, so can't use the openai sdk
|
||||
# well supported replicate llms
|
||||
replicate_models: List = [
|
||||
# llama replicate supported LLMs
|
||||
"replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
|
||||
"a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
|
||||
"meta/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db",
|
||||
# Vicuna
|
||||
"replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
|
||||
"joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
|
||||
# Flan T-5
|
||||
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
|
||||
# Others
|
||||
"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
|
||||
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
|
||||
]
|
||||
|
||||
clarifai_models: List = [
|
||||
"clarifai/meta.Llama-3.Llama-3-8B-Instruct",
|
||||
"clarifai/gcp.generate.gemma-1_1-7b-it",
|
||||
"clarifai/mistralai.completion.mixtral-8x22B",
|
||||
"clarifai/cohere.generate.command-r-plus",
|
||||
"clarifai/databricks.drbx.dbrx-instruct",
|
||||
"clarifai/mistralai.completion.mistral-large",
|
||||
"clarifai/mistralai.completion.mistral-medium",
|
||||
"clarifai/mistralai.completion.mistral-small",
|
||||
"clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
|
||||
"clarifai/gcp.generate.gemma-2b-it",
|
||||
"clarifai/gcp.generate.gemma-7b-it",
|
||||
"clarifai/deci.decilm.deciLM-7B-instruct",
|
||||
"clarifai/mistralai.completion.mistral-7B-Instruct",
|
||||
"clarifai/gcp.generate.gemini-pro",
|
||||
"clarifai/anthropic.completion.claude-v1",
|
||||
"clarifai/anthropic.completion.claude-instant-1_2",
|
||||
"clarifai/anthropic.completion.claude-instant",
|
||||
"clarifai/anthropic.completion.claude-v2",
|
||||
"clarifai/anthropic.completion.claude-2_1",
|
||||
"clarifai/meta.Llama-2.codeLlama-70b-Python",
|
||||
"clarifai/meta.Llama-2.codeLlama-70b-Instruct",
|
||||
"clarifai/openai.completion.gpt-3_5-turbo-instruct",
|
||||
"clarifai/meta.Llama-2.llama2-7b-chat",
|
||||
"clarifai/meta.Llama-2.llama2-13b-chat",
|
||||
"clarifai/meta.Llama-2.llama2-70b-chat",
|
||||
"clarifai/openai.chat-completion.gpt-4-turbo",
|
||||
"clarifai/microsoft.text-generation.phi-2",
|
||||
"clarifai/meta.Llama-2.llama2-7b-chat-vllm",
|
||||
"clarifai/upstage.solar.solar-10_7b-instruct",
|
||||
"clarifai/openchat.openchat.openchat-3_5-1210",
|
||||
"clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
|
||||
"clarifai/gcp.generate.text-bison",
|
||||
"clarifai/meta.Llama-2.llamaGuard-7b",
|
||||
"clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
|
||||
"clarifai/openai.chat-completion.GPT-4",
|
||||
"clarifai/openai.chat-completion.GPT-3_5-turbo",
|
||||
"clarifai/ai21.complete.Jurassic2-Grande",
|
||||
"clarifai/ai21.complete.Jurassic2-Grande-Instruct",
|
||||
"clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
|
||||
"clarifai/ai21.complete.Jurassic2-Jumbo",
|
||||
"clarifai/ai21.complete.Jurassic2-Large",
|
||||
"clarifai/cohere.generate.cohere-generate-command",
|
||||
"clarifai/wizardlm.generate.wizardCoder-Python-34B",
|
||||
"clarifai/wizardlm.generate.wizardLM-70B",
|
||||
"clarifai/tiiuae.falcon.falcon-40b-instruct",
|
||||
"clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
|
||||
"clarifai/gcp.generate.code-gecko",
|
||||
"clarifai/gcp.generate.code-bison",
|
||||
"clarifai/mistralai.completion.mistral-7B-OpenOrca",
|
||||
"clarifai/mistralai.completion.openHermes-2-mistral-7B",
|
||||
"clarifai/wizardlm.generate.wizardLM-13B",
|
||||
"clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
|
||||
"clarifai/wizardlm.generate.wizardCoder-15B",
|
||||
"clarifai/microsoft.text-generation.phi-1_5",
|
||||
"clarifai/databricks.Dolly-v2.dolly-v2-12b",
|
||||
"clarifai/bigcode.code.StarCoder",
|
||||
"clarifai/salesforce.xgen.xgen-7b-8k-instruct",
|
||||
"clarifai/mosaicml.mpt.mpt-7b-instruct",
|
||||
"clarifai/anthropic.completion.claude-3-opus",
|
||||
"clarifai/anthropic.completion.claude-3-sonnet",
|
||||
"clarifai/gcp.generate.gemini-1_5-pro",
|
||||
"clarifai/gcp.generate.imagen-2",
|
||||
"clarifai/salesforce.blip.general-english-image-caption-blip-2",
|
||||
]
|
||||
|
||||
|
||||
huggingface_models: List = [
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
"meta-llama/Llama-2-13b-chat-hf",
|
||||
"meta-llama/Llama-2-70b-hf",
|
||||
"meta-llama/Llama-2-70b-chat-hf",
|
||||
"meta-llama/Llama-2-7b",
|
||||
"meta-llama/Llama-2-7b-chat",
|
||||
"meta-llama/Llama-2-13b",
|
||||
"meta-llama/Llama-2-13b-chat",
|
||||
"meta-llama/Llama-2-70b",
|
||||
"meta-llama/Llama-2-70b-chat",
|
||||
] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/providers
|
||||
empower_models = [
|
||||
"empower/empower-functions",
|
||||
"empower/empower-functions-small",
|
||||
]
|
||||
|
||||
together_ai_models: List = [
|
||||
# llama llms - chat
|
||||
"togethercomputer/llama-2-70b-chat",
|
||||
# llama llms - language / instruct
|
||||
"togethercomputer/llama-2-70b",
|
||||
"togethercomputer/LLaMA-2-7B-32K",
|
||||
"togethercomputer/Llama-2-7B-32K-Instruct",
|
||||
"togethercomputer/llama-2-7b",
|
||||
# falcon llms
|
||||
"togethercomputer/falcon-40b-instruct",
|
||||
"togethercomputer/falcon-7b-instruct",
|
||||
# alpaca
|
||||
"togethercomputer/alpaca-7b",
|
||||
# chat llms
|
||||
"HuggingFaceH4/starchat-alpha",
|
||||
# code llms
|
||||
"togethercomputer/CodeLlama-34b",
|
||||
"togethercomputer/CodeLlama-34b-Instruct",
|
||||
"togethercomputer/CodeLlama-34b-Python",
|
||||
"defog/sqlcoder",
|
||||
"NumbersStation/nsql-llama-2-7B",
|
||||
"WizardLM/WizardCoder-15B-V1.0",
|
||||
"WizardLM/WizardCoder-Python-34B-V1.0",
|
||||
# language llms
|
||||
"NousResearch/Nous-Hermes-Llama2-13b",
|
||||
"Austism/chronos-hermes-13b",
|
||||
"upstage/SOLAR-0-70b-16bit",
|
||||
"WizardLM/WizardLM-70B-V1.0",
|
||||
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
|
||||
|
||||
|
||||
baseten_models: List = [
|
||||
"qvv0xeq",
|
||||
"q841o8w",
|
||||
"31dxrj3",
|
||||
] # FALCON 7B # WizardLM # Mosaic ML
|
||||
|
||||
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
|
||||
"cohere",
|
||||
"anthropic",
|
||||
"mistral",
|
||||
"amazon",
|
||||
"meta",
|
||||
"llama",
|
||||
"ai21",
|
||||
"nova",
|
||||
"deepseek_r1",
|
||||
]
|
||||
|
||||
open_ai_embedding_models: List = ["text-embedding-ada-002"]
|
||||
cohere_embedding_models: List = [
|
||||
"embed-english-v3.0",
|
||||
"embed-english-light-v3.0",
|
||||
"embed-multilingual-v3.0",
|
||||
"embed-english-v2.0",
|
||||
"embed-english-light-v2.0",
|
||||
"embed-multilingual-v2.0",
|
||||
]
|
||||
bedrock_embedding_models: List = [
|
||||
"amazon.titan-embed-text-v1",
|
||||
"cohere.embed-english-v3",
|
||||
"cohere.embed-multilingual-v3",
|
||||
]
|
||||
|
||||
known_tokenizer_config = {
|
||||
"mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct": {
|
||||
"tokenizer": {
|
||||
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
|
||||
"bos_token": "<|begin_of_text|>",
|
||||
"eos_token": "",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
"deepseek-r1/deepseek-r1-7b-instruct": {
|
||||
"tokenizer": {
|
||||
"add_bos_token": True,
|
||||
"add_eos_token": False,
|
||||
"bos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|begin▁of▁sentence|>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
},
|
||||
"clean_up_tokenization_spaces": False,
|
||||
"eos_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|end▁of▁sentence|>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
},
|
||||
"legacy": True,
|
||||
"model_max_length": 16384,
|
||||
"pad_token": {
|
||||
"__type": "AddedToken",
|
||||
"content": "<|end▁of▁sentence|>",
|
||||
"lstrip": False,
|
||||
"normalized": True,
|
||||
"rstrip": False,
|
||||
"single_word": False,
|
||||
},
|
||||
"sp_model_kwargs": {},
|
||||
"unk_token": None,
|
||||
"tokenizer_class": "LlamaTokenizerFast",
|
||||
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}",
|
||||
},
|
||||
"status": "success",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
OPENAI_FINISH_REASONS = ["stop", "length", "function_call", "content_filter", "null"]
|
||||
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = 60 # 1 minute
|
||||
RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when converting response format to tool call
|
||||
|
||||
########################### Logging Callback Constants ###########################
|
||||
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
||||
PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
|
||||
MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
||||
|
||||
########################### LiteLLM Proxy Specific Constants ###########################
|
||||
########################################################################################
|
||||
MAX_SPENDLOG_ROWS_TO_QUERY = (
|
||||
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
|
||||
)
|
||||
DEFAULT_SOFT_BUDGET = (
|
||||
50.0 # by default all litellm proxy keys have a soft budget of 50.0
|
||||
)
|
||||
# makes it clear this is a rate limit error for a litellm virtual key
|
||||
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
|
||||
|
||||
# pass through route constansts
|
||||
BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
|
||||
"agents/",
|
||||
"knowledgebases/",
|
||||
"flows/",
|
||||
"retrieveAndGenerate/",
|
||||
"rerank/",
|
||||
"generateQuery/",
|
||||
"optimize-prompt/",
|
||||
]
|
||||
|
||||
BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600 # 1 hour
|
||||
BATCH_STATUS_POLL_MAX_ATTEMPTS = 24 # for 24 hours
|
||||
|
||||
HEALTH_CHECK_TIMEOUT_SECONDS = 60 # 60 seconds
|
||||
|
||||
UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard"
|
||||
LITELLM_PROXY_ADMIN_NAME = "default_user_id"
|
||||
|
||||
########################### DB CRON JOB NAMES ###########################
|
||||
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
|
||||
PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
|
||||
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
|
||||
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
|
||||
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
|
||||
PROXY_BATCH_WRITE_AT = 10 # in seconds
|
||||
DEFAULT_HEALTH_CHECK_INTERVAL = 300 # 5 minutes
|
||||
PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = 9
|
||||
DEFAULT_MODEL_CREATED_AT_TIME = 1677610602 # returns on `/models` endpoint
|
||||
DEFAULT_SLACK_ALERTING_THRESHOLD = 300
|
||||
MAX_TEAM_LIST_LIMIT = 20
|
||||
DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = 0.7
|
||||
LENGTH_OF_LITELLM_GENERATED_KEY = 16
|
||||
SECRET_MANAGER_REFRESH_INTERVAL = 86400
|
||||
5
.venv/lib/python3.10/site-packages/litellm/cost.json
Normal file
5
.venv/lib/python3.10/site-packages/litellm/cost.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gpt-3.5-turbo-0613": 0.00015000000000000001,
|
||||
"claude-2": 0.00016454,
|
||||
"gpt-4-0613": 0.015408
|
||||
}
|
||||
1350
.venv/lib/python3.10/site-packages/litellm/cost_calculator.py
Normal file
1350
.venv/lib/python3.10/site-packages/litellm/cost_calculator.py
Normal file
File diff suppressed because it is too large
Load Diff
808
.venv/lib/python3.10/site-packages/litellm/exceptions.py
Normal file
808
.venv/lib/python3.10/site-packages/litellm/exceptions.py
Normal file
@@ -0,0 +1,808 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
## LiteLLM versions of the OpenAI Exception Types
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
||||
from litellm.types.utils import LiteLLMCommonStrings
|
||||
|
||||
|
||||
class AuthenticationError(openai.AuthenticationError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 401
|
||||
self.message = "litellm.AuthenticationError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
self.response = response or httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
)
|
||||
super().__init__(
|
||||
self.message, response=self.response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
# raise when invalid models passed, example gpt-8
|
||||
class NotFoundError(openai.NotFoundError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 404
|
||||
self.message = "litellm.NotFoundError: {}".format(message)
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
self.response = response or httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
)
|
||||
super().__init__(
|
||||
self.message, response=self.response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class BadRequestError(openai.BadRequestError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
body: Optional[dict] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.BadRequestError: {}".format(message)
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
response = httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
)
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
super().__init__(
|
||||
self.message, response=response, body=body
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 422
|
||||
self.message = "litellm.UnprocessableEntityError: {}".format(message)
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class Timeout(openai.APITimeoutError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
request = httpx.Request(
|
||||
method="POST",
|
||||
url="https://api.openai.com/v1",
|
||||
)
|
||||
super().__init__(
|
||||
request=request
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
self.status_code = 408
|
||||
self.message = "litellm.Timeout: {}".format(message)
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
self.headers = headers
|
||||
|
||||
# custom function to convert to str
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: httpx.Response,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 403
|
||||
self.message = "litellm.PermissionDeniedError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
super().__init__(
|
||||
self.message, response=response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class RateLimitError(openai.RateLimitError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 429
|
||||
self.message = "litellm.RateLimitError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
_response_headers = (
|
||||
getattr(response, "headers", None) if response is not None else None
|
||||
)
|
||||
self.response = httpx.Response(
|
||||
status_code=429,
|
||||
headers=_response_headers,
|
||||
request=httpx.Request(
|
||||
method="POST",
|
||||
url=" https://cloud.google.com/vertex-ai/",
|
||||
),
|
||||
)
|
||||
super().__init__(
|
||||
self.message, response=self.response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
self.code = "429"
|
||||
self.type = "throttling_error"
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
|
||||
class ContextWindowExceededError(BadRequestError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
self.response = httpx.Response(status_code=400, request=request)
|
||||
super().__init__(
|
||||
message=message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=self.response,
|
||||
litellm_debug_info=self.litellm_debug_info,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
# set after, to make it clear the raised error is a context window exceeded error
|
||||
self.message = "litellm.ContextWindowExceededError: {}".format(self.message)
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
|
||||
class RejectedRequestError(BadRequestError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
request_data: dict,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.RejectedRequestError: {}".format(message)
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.request_data = request_data
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
response = httpx.Response(status_code=400, request=request)
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=response,
|
||||
litellm_debug_info=self.litellm_debug_info,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class ContentPolicyViolationError(BadRequestError): # type: ignore
|
||||
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
model,
|
||||
llm_provider,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.ContentPolicyViolationError: {}".format(message)
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
self.response = httpx.Response(status_code=400, request=request)
|
||||
super().__init__(
|
||||
message=self.message,
|
||||
model=self.model, # type: ignore
|
||||
llm_provider=self.llm_provider, # type: ignore
|
||||
response=self.response,
|
||||
litellm_debug_info=self.litellm_debug_info,
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class ServiceUnavailableError(openai.APIStatusError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 503
|
||||
self.message = "litellm.ServiceUnavailableError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
self.response = httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
method="POST",
|
||||
url=" https://cloud.google.com/vertex-ai/",
|
||||
),
|
||||
)
|
||||
super().__init__(
|
||||
self.message, response=self.response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class InternalServerError(openai.InternalServerError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 500
|
||||
self.message = "litellm.InternalServerError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
self.response = httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
method="POST",
|
||||
url=" https://cloud.google.com/vertex-ai/",
|
||||
),
|
||||
)
|
||||
super().__init__(
|
||||
self.message, response=self.response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
|
||||
class APIError(openai.APIError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: Optional[httpx.Request] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = "litellm.APIError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
if request is None:
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
super().__init__(self.message, request=request, body=None) # type: ignore
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIConnectionError(openai.APIConnectionError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: Optional[httpx.Request] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.message = "litellm.APIConnectionError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.status_code = 500
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
super().__init__(message=self.message, request=self.request)
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
# raised if an invalid request (not get, delete, put, post) is made
|
||||
class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.message = "litellm.APIResponseValidationError: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
response = httpx.Response(status_code=500, request=request)
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
super().__init__(response=response, body=None, message=message)
|
||||
|
||||
def __str__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
def __repr__(self):
|
||||
_message = self.message
|
||||
if self.num_retries:
|
||||
_message += f" LiteLLM Retried: {self.num_retries} times"
|
||||
if self.max_retries:
|
||||
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
||||
return _message
|
||||
|
||||
|
||||
class JSONSchemaValidationError(APIResponseValidationError):
|
||||
def __init__(
|
||||
self, model: str, llm_provider: str, raw_response: str, schema: str
|
||||
) -> None:
|
||||
self.raw_response = raw_response
|
||||
self.schema = schema
|
||||
self.model = model
|
||||
message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
|
||||
model, raw_response, schema
|
||||
)
|
||||
self.message = message
|
||||
super().__init__(model=model, message=message, llm_provider=llm_provider)
|
||||
|
||||
|
||||
class OpenAIError(openai.OpenAIError): # type: ignore
|
||||
def __init__(self, original_exception=None):
|
||||
super().__init__()
|
||||
self.llm_provider = "openai"
|
||||
|
||||
|
||||
class UnsupportedParamsError(BadRequestError):
|
||||
def __init__(
|
||||
self,
|
||||
message,
|
||||
llm_provider: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
status_code: int = 400,
|
||||
response: Optional[httpx.Response] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = 400
|
||||
self.message = "litellm.UnsupportedParamsError: {}".format(message)
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
response = response or httpx.Response(
|
||||
status_code=self.status_code,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
)
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
|
||||
|
||||
LITELLM_EXCEPTION_TYPES = [
|
||||
AuthenticationError,
|
||||
NotFoundError,
|
||||
BadRequestError,
|
||||
UnprocessableEntityError,
|
||||
UnsupportedParamsError,
|
||||
Timeout,
|
||||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
ContextWindowExceededError,
|
||||
RejectedRequestError,
|
||||
ContentPolicyViolationError,
|
||||
InternalServerError,
|
||||
ServiceUnavailableError,
|
||||
APIError,
|
||||
APIConnectionError,
|
||||
APIResponseValidationError,
|
||||
OpenAIError,
|
||||
InternalServerError,
|
||||
JSONSchemaValidationError,
|
||||
]
|
||||
|
||||
|
||||
class BudgetExceededError(Exception):
|
||||
def __init__(
|
||||
self, current_cost: float, max_budget: float, message: Optional[str] = None
|
||||
):
|
||||
self.current_cost = current_cost
|
||||
self.max_budget = max_budget
|
||||
message = (
|
||||
message
|
||||
or f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
|
||||
)
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
## DEPRECATED ##
|
||||
class InvalidRequestError(openai.BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
self.status_code = 400
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
self.response = httpx.Response(
|
||||
status_code=400,
|
||||
request=httpx.Request(
|
||||
method="GET", url="https://litellm.ai"
|
||||
), # mock request object
|
||||
)
|
||||
super().__init__(
|
||||
message=self.message, response=self.response, body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class MockException(openai.APIError):
|
||||
# used for testing
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message,
|
||||
llm_provider,
|
||||
model,
|
||||
request: Optional[httpx.Request] = None,
|
||||
litellm_debug_info: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = "litellm.MockException: {}".format(message)
|
||||
self.llm_provider = llm_provider
|
||||
self.model = model
|
||||
self.litellm_debug_info = litellm_debug_info
|
||||
self.max_retries = max_retries
|
||||
self.num_retries = num_retries
|
||||
if request is None:
|
||||
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
||||
super().__init__(self.message, request=request, body=None) # type: ignore
|
||||
|
||||
|
||||
class LiteLLMUnknownProvider(BadRequestError):
|
||||
def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
|
||||
self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
super().__init__(
|
||||
self.message, model=model, llm_provider=custom_llm_provider, response=None
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
@@ -0,0 +1,6 @@
|
||||
# LiteLLM MCP Client
|
||||
|
||||
LiteLLM MCP Client is a client that allows you to use MCP tools with LiteLLM.
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import call_openai_tool, load_mcp_tools
|
||||
|
||||
__all__ = ["load_mcp_tools", "call_openai_tool"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,111 @@
|
||||
import json
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import Tool as MCPTool
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
||||
from litellm.types.utils import ChatCompletionMessageToolCall
|
||||
|
||||
|
||||
########################################################
|
||||
# List MCP Tool functions
|
||||
########################################################
|
||||
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
|
||||
"""Convert an MCP tool to an OpenAI tool."""
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=FunctionDefinition(
|
||||
name=mcp_tool.name,
|
||||
description=mcp_tool.description or "",
|
||||
parameters=mcp_tool.inputSchema,
|
||||
strict=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
|
||||
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
|
||||
"""
|
||||
Load all available MCP tools
|
||||
|
||||
Args:
|
||||
session: The MCP session to use
|
||||
format: The format to convert the tools to
|
||||
By default, the tools are returned in MCP format.
|
||||
|
||||
If format is set to "openai", the tools are converted to OpenAI API compatible tools.
|
||||
"""
|
||||
tools = await session.list_tools()
|
||||
if format == "openai":
|
||||
return [
|
||||
transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools
|
||||
]
|
||||
return tools.tools
|
||||
|
||||
|
||||
########################################################
|
||||
# Call MCP Tool functions
|
||||
########################################################
|
||||
|
||||
|
||||
async def call_mcp_tool(
|
||||
session: ClientSession,
|
||||
call_tool_request_params: MCPCallToolRequestParams,
|
||||
) -> MCPCallToolResult:
|
||||
"""Call an MCP tool."""
|
||||
tool_result = await session.call_tool(
|
||||
name=call_tool_request_params.name,
|
||||
arguments=call_tool_request_params.arguments,
|
||||
)
|
||||
return tool_result
|
||||
|
||||
|
||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
||||
"""Helper to safely get and parse function arguments."""
|
||||
arguments = function.get("arguments", {})
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
return arguments if isinstance(arguments, dict) else {}
|
||||
|
||||
|
||||
def transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
|
||||
) -> MCPCallToolRequestParams:
|
||||
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
|
||||
function = openai_tool["function"]
|
||||
return MCPCallToolRequestParams(
|
||||
name=function["name"],
|
||||
arguments=_get_function_arguments(function),
|
||||
)
|
||||
|
||||
|
||||
async def call_openai_tool(
|
||||
session: ClientSession,
|
||||
openai_tool: ChatCompletionMessageToolCall,
|
||||
) -> MCPCallToolResult:
|
||||
"""
|
||||
Call an OpenAI tool using MCP client.
|
||||
|
||||
Args:
|
||||
session: The MCP session to use
|
||||
openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
|
||||
Returns:
|
||||
The result of the MCP tool call.
|
||||
"""
|
||||
mcp_tool_call_request_params = (
|
||||
transform_openai_tool_call_request_to_mcp_tool_call_request(
|
||||
openai_tool=openai_tool,
|
||||
)
|
||||
)
|
||||
return await call_mcp_tool(
|
||||
session=session,
|
||||
call_tool_request_params=mcp_tool_call_request_params,
|
||||
)
|
||||
Binary file not shown.
886
.venv/lib/python3.10/site-packages/litellm/files/main.py
Normal file
886
.venv/lib/python3.10/site-packages/litellm/files/main.py
Normal file
@@ -0,0 +1,886 @@
|
||||
"""
|
||||
Main File for Files API implementation
|
||||
|
||||
https://platform.openai.com/docs/api-reference/files
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret_str
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
|
||||
from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
|
||||
from litellm.types.llms.openai import (
|
||||
CreateFileRequest,
|
||||
FileContentRequest,
|
||||
FileTypes,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.router import *
|
||||
from litellm.types.utils import LlmProviders
|
||||
from litellm.utils import (
|
||||
ProviderConfigManager,
|
||||
client,
|
||||
get_litellm_params,
|
||||
supports_httpx_timeout,
|
||||
)
|
||||
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_files_instance = OpenAIFilesAPI()
|
||||
azure_files_instance = AzureOpenAIFilesAPI()
|
||||
vertex_ai_files_instance = VertexAIFilesHandler()
|
||||
#################################################
|
||||
|
||||
|
||||
@client
|
||||
async def acreate_file(
|
||||
file: FileTypes,
|
||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
||||
|
||||
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["acreate_file"] = True
|
||||
|
||||
call_args = {
|
||||
"file": file,
|
||||
"purpose": purpose,
|
||||
"custom_llm_provider": custom_llm_provider,
|
||||
"extra_headers": extra_headers,
|
||||
"extra_body": extra_body,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(create_file, **call_args)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@client
|
||||
def create_file(
|
||||
file: FileTypes,
|
||||
purpose: Literal["assistants", "batch", "fine-tune"],
|
||||
custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
|
||||
"""
|
||||
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
||||
|
||||
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
||||
|
||||
Specify either provider_list or custom_llm_provider.
|
||||
"""
|
||||
try:
|
||||
_is_async = kwargs.pop("acreate_file", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
logging_obj = cast(
|
||||
Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
|
||||
)
|
||||
if logging_obj is None:
|
||||
raise ValueError("logging_obj is required")
|
||||
client = kwargs.get("client")
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_create_file_request = CreateFileRequest(
|
||||
file=file,
|
||||
purpose=purpose,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
provider_config = ProviderConfigManager.get_provider_files_config(
|
||||
model="",
|
||||
provider=LlmProviders(custom_llm_provider),
|
||||
)
|
||||
if provider_config is not None:
|
||||
response = base_llm_http_handler.create_file(
|
||||
provider_config=provider_config,
|
||||
litellm_params=litellm_params_dict,
|
||||
create_file_data=_create_file_request,
|
||||
headers=extra_headers or {},
|
||||
api_base=optional_params.api_base,
|
||||
api_key=optional_params.api_key,
|
||||
logging_obj=logging_obj,
|
||||
_is_async=_is_async,
|
||||
client=client
|
||||
if client is not None
|
||||
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
|
||||
else None,
|
||||
timeout=timeout,
|
||||
)
|
||||
elif custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_files_instance.create_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
create_file_data=_create_file_request,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.create_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_file_data=_create_file_request,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
|
||||
response = vertex_ai_files_instance.create_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
create_file_data=_create_file_request,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai'] are supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
async def afile_retrieve(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Async: Get file contents
|
||||
|
||||
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["is_async"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
file_retrieve,
|
||||
file_id,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def file_retrieve(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> FileObject:
|
||||
"""
|
||||
Returns the contents of the specified file.
|
||||
|
||||
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("is_async", False) is True
|
||||
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_files_instance.retrieve_file(
|
||||
file_id=file_id,
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.retrieve_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
file_id=file_id,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai' and 'azure' are supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return cast(FileObject, response)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
# Delete file
|
||||
async def afile_delete(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Coroutine[Any, Any, FileObject]:
|
||||
"""
|
||||
Async: Delete file
|
||||
|
||||
LiteLLM Equivalent of DELETE https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["is_async"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
file_delete,
|
||||
file_id,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
|
||||
return cast(FileDeleted, response) # type: ignore
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def file_delete(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> FileDeleted:
|
||||
"""
|
||||
Delete file
|
||||
|
||||
LiteLLM Equivalent of DELETE https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
client = kwargs.get("client")
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
_is_async = kwargs.pop("is_async", False) is True
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
response = openai_files_instance.delete_file(
|
||||
file_id=file_id,
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.delete_file(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
file_id=file_id,
|
||||
client=client,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return cast(FileDeleted, response)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
# List files
|
||||
async def afile_list(
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
purpose: Optional[str] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Async: List files
|
||||
|
||||
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["is_async"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
file_list,
|
||||
custom_llm_provider,
|
||||
purpose,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def file_list(
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
purpose: Optional[str] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
List files
|
||||
|
||||
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("is_async", False) is True
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_files_instance.list_files(
|
||||
purpose=purpose,
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.list_files(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
purpose=purpose,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'file_list'. Only 'openai' and 'azure' are supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="file_list", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
async def afile_content(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
"""
|
||||
Async: Get file contents
|
||||
|
||||
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["afile_content"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
file_content,
|
||||
file_id,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def file_content(
|
||||
file_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]:
|
||||
"""
|
||||
Returns the contents of the specified file.
|
||||
|
||||
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params_dict = get_litellm_params(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
client = kwargs.get("client")
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_file_content_request = FileContentRequest(
|
||||
file_id=file_id,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
_is_async = kwargs.pop("afile_content", False) is True
|
||||
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_files_instance.file_content(
|
||||
_is_async=_is_async,
|
||||
file_content_request=_file_content_request,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.file_content(
|
||||
_is_async=_is_async,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
file_content_request=_file_content_request,
|
||||
client=client,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
Binary file not shown.
757
.venv/lib/python3.10/site-packages/litellm/fine_tuning/main.py
Normal file
757
.venv/lib/python3.10/site-packages/litellm/fine_tuning/main.py
Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Main File for Fine Tuning API implementation
|
||||
|
||||
https://platform.openai.com/docs/api-reference/fine-tuning
|
||||
|
||||
- fine_tuning.jobs.create()
|
||||
- fine_tuning.jobs.list()
|
||||
- client.fine_tuning.jobs.list_events()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure.fine_tuning.handler import AzureOpenAIFineTuningAPI
|
||||
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
|
||||
from litellm.llms.vertex_ai.fine_tuning.handler import VertexFineTuningAPI
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
FineTuningJob,
|
||||
FineTuningJobCreate,
|
||||
Hyperparameters,
|
||||
)
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import client, supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
|
||||
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
|
||||
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
|
||||
#################################################
|
||||
|
||||
|
||||
@client
|
||||
async def acreate_fine_tuning_job(
|
||||
model: str,
|
||||
training_file: str,
|
||||
hyperparameters: Optional[dict] = {},
|
||||
suffix: Optional[str] = None,
|
||||
validation_file: Optional[str] = None,
|
||||
integrations: Optional[List[str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> FineTuningJob:
|
||||
"""
|
||||
Async: Creates and executes a batch from an uploaded file of request
|
||||
|
||||
"""
|
||||
verbose_logger.debug(
|
||||
"inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
|
||||
)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["acreate_fine_tuning_job"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
create_fine_tuning_job,
|
||||
model,
|
||||
training_file,
|
||||
hyperparameters,
|
||||
suffix,
|
||||
validation_file,
|
||||
integrations,
|
||||
seed,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@client
|
||||
def create_fine_tuning_job(
|
||||
model: str,
|
||||
training_file: str,
|
||||
hyperparameters: Optional[dict] = {},
|
||||
suffix: Optional[str] = None,
|
||||
validation_file: Optional[str] = None,
|
||||
integrations: Optional[List[str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
|
||||
"""
|
||||
Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
|
||||
|
||||
Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
|
||||
|
||||
"""
|
||||
try:
|
||||
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
# handle hyperparameters
|
||||
hyperparameters = hyperparameters or {} # original hyperparameters
|
||||
_oai_hyperparameters: Hyperparameters = Hyperparameters(
|
||||
**hyperparameters
|
||||
) # Typed Hyperparameters for OpenAI Spec
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
model=model,
|
||||
training_file=training_file,
|
||||
hyperparameters=_oai_hyperparameters,
|
||||
suffix=suffix,
|
||||
validation_file=validation_file,
|
||||
integrations=integrations,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
|
||||
response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=optional_params.api_version,
|
||||
organization=organization,
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get(
|
||||
"client", None
|
||||
), # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
model=model,
|
||||
training_file=training_file,
|
||||
hyperparameters=_oai_hyperparameters,
|
||||
suffix=suffix,
|
||||
validation_file=validation_file,
|
||||
integrations=integrations,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
|
||||
response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
organization=optional_params.organization,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
api_base = optional_params.api_base or ""
|
||||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
model=model,
|
||||
training_file=training_file,
|
||||
hyperparameters=_oai_hyperparameters,
|
||||
suffix=suffix,
|
||||
validation_file=validation_file,
|
||||
integrations=integrations,
|
||||
seed=seed,
|
||||
)
|
||||
response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
|
||||
_is_async=_is_async,
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data,
|
||||
vertex_credentials=vertex_credentials,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_location=vertex_ai_location,
|
||||
timeout=timeout,
|
||||
api_base=api_base,
|
||||
kwargs=kwargs,
|
||||
original_hyperparameters=hyperparameters,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
|
||||
raise e
|
||||
|
||||
|
||||
async def acancel_fine_tuning_job(
|
||||
fine_tuning_job_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> FineTuningJob:
|
||||
"""
|
||||
Async: Immediately cancel a fine-tune job.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["acancel_fine_tuning_job"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
cancel_fine_tuning_job,
|
||||
fine_tuning_job_id,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def cancel_fine_tuning_job(
|
||||
fine_tuning_job_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
|
||||
"""
|
||||
Immediately cancel a fine-tune job.
|
||||
|
||||
Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
|
||||
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
|
||||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=optional_params.api_version,
|
||||
organization=organization,
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client", None),
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
organization=optional_params.organization,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
async def alist_fine_tuning_jobs(
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Async: List your organization's fine-tuning jobs
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["alist_fine_tuning_jobs"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
list_fine_tuning_jobs,
|
||||
after,
|
||||
limit,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def list_fine_tuning_jobs(
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
List your organization's fine-tuning jobs
|
||||
|
||||
Params:
|
||||
|
||||
- after: Optional[str] = None, Identifier for the last job from the previous pagination request.
|
||||
- limit: Optional[int] = None, Number of fine-tuning jobs to retrieve. Defaults to 20
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
|
||||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=optional_params.api_version,
|
||||
organization=organization,
|
||||
after=after,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client", None),
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
after=after,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
organization=optional_params.organization,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
async def aretrieve_fine_tuning_job(
|
||||
fine_tuning_job_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> FineTuningJob:
|
||||
"""
|
||||
Async: Get info about a fine-tuning job.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
kwargs["aretrieve_fine_tuning_job"] = True
|
||||
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(
|
||||
retrieve_fine_tuning_job,
|
||||
fine_tuning_job_id,
|
||||
custom_llm_provider,
|
||||
extra_headers,
|
||||
extra_body,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def retrieve_fine_tuning_job(
|
||||
fine_tuning_job_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
|
||||
"""
|
||||
Get info about a fine-tuning job.
|
||||
"""
|
||||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
||||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
_is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True
|
||||
|
||||
# OpenAI
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
or litellm.api_base
|
||||
or os.getenv("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
organization = (
|
||||
optional_params.organization
|
||||
or litellm.organization
|
||||
or os.getenv("OPENAI_ORGANIZATION", None)
|
||||
or None
|
||||
)
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=optional_params.api_version,
|
||||
organization=organization,
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
client=kwargs.get("client", None),
|
||||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
if extra_body is not None:
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
api_version=api_version,
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
_is_async=_is_async,
|
||||
organization=optional_params.organization,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
|
||||
custom_llm_provider
|
||||
),
|
||||
model="n/a",
|
||||
llm_provider=custom_llm_provider,
|
||||
response=httpx.Response(
|
||||
status_code=400,
|
||||
content="Unsupported provider",
|
||||
request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -0,0 +1,5 @@
|
||||
# Integrations
|
||||
|
||||
This folder contains logging integrations for litellm
|
||||
|
||||
eg. logging to Datadog, Langfuse, Prometheus, s3, GCS Bucket, etc.
|
||||
@@ -0,0 +1,13 @@
|
||||
# Slack Alerting on LiteLLM Gateway
|
||||
|
||||
This folder contains the Slack Alerting integration for LiteLLM Gateway.
|
||||
|
||||
## Folder Structure
|
||||
|
||||
- `slack_alerting.py`: This is the main file that handles sending different types of alerts
|
||||
- `batching_handler.py`: Handles Batching + sending Httpx Post requests to slack. Slack alerts are sent every 10s or when events are greater than X events. Done to ensure litellm has good performance under high traffic
|
||||
- `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack.
|
||||
- `utils.py`: This file contains common utils used specifically for slack alerting
|
||||
|
||||
## Further Reading
|
||||
- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Handles Batching + sending Httpx Post requests to slack
|
||||
|
||||
Slack alerts are sent every 10s or when events are greater than X events
|
||||
|
||||
see custom_batch_logger.py for more details / defaults
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .slack_alerting import SlackAlerting as _SlackAlerting
|
||||
|
||||
SlackAlertingType = _SlackAlerting
|
||||
else:
|
||||
SlackAlertingType = Any
|
||||
|
||||
|
||||
def squash_payloads(queue):
|
||||
squashed = {}
|
||||
if len(queue) == 0:
|
||||
return squashed
|
||||
if len(queue) == 1:
|
||||
return {"key": {"item": queue[0], "count": 1}}
|
||||
|
||||
for item in queue:
|
||||
url = item["url"]
|
||||
alert_type = item["alert_type"]
|
||||
_key = (url, alert_type)
|
||||
|
||||
if _key in squashed:
|
||||
squashed[_key]["count"] += 1
|
||||
# Merge the payloads
|
||||
|
||||
else:
|
||||
squashed[_key] = {"item": item, "count": 1}
|
||||
|
||||
return squashed
|
||||
|
||||
|
||||
def _print_alerting_payload_warning(
|
||||
payload: dict, slackAlertingInstance: SlackAlertingType
|
||||
):
|
||||
"""
|
||||
Print the payload to the console when
|
||||
slackAlertingInstance.alerting_args.log_to_console is True
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/7372
|
||||
"""
|
||||
if slackAlertingInstance.alerting_args.log_to_console is True:
|
||||
verbose_proxy_logger.warning(payload)
|
||||
|
||||
|
||||
async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count):
|
||||
"""
|
||||
Send a single slack alert to the webhook
|
||||
"""
|
||||
import json
|
||||
|
||||
payload = item.get("payload", {})
|
||||
try:
|
||||
if count > 1:
|
||||
payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"
|
||||
|
||||
response = await slackAlertingInstance.async_http_handler.post(
|
||||
url=item["url"],
|
||||
headers=item["headers"],
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Error sending slack alert to url={item['url']}. Error={response.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
|
||||
finally:
|
||||
_print_alerting_payload_warning(
|
||||
payload, slackAlertingInstance=slackAlertingInstance
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Utils used for slack alerting
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm.proxy._types import AlertType
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _Logging
|
||||
|
||||
Logging = _Logging
|
||||
else:
|
||||
Logging = Any
|
||||
|
||||
|
||||
def process_slack_alerting_variables(
|
||||
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
|
||||
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
|
||||
"""
|
||||
process alert_to_webhook_url
|
||||
- check if any urls are set as os.environ/SLACK_WEBHOOK_URL_1 read env var and set the correct value
|
||||
"""
|
||||
if alert_to_webhook_url is None:
|
||||
return None
|
||||
|
||||
for alert_type, webhook_urls in alert_to_webhook_url.items():
|
||||
if isinstance(webhook_urls, list):
|
||||
_webhook_values: List[str] = []
|
||||
for webhook_url in webhook_urls:
|
||||
if "os.environ/" in webhook_url:
|
||||
_env_value = get_secret(secret_name=webhook_url)
|
||||
if not isinstance(_env_value, str):
|
||||
raise ValueError(
|
||||
f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
|
||||
)
|
||||
_webhook_values.append(_env_value)
|
||||
else:
|
||||
_webhook_values.append(webhook_url)
|
||||
|
||||
alert_to_webhook_url[alert_type] = _webhook_values
|
||||
else:
|
||||
_webhook_value_str: str = webhook_urls
|
||||
if "os.environ/" in webhook_urls:
|
||||
_env_value = get_secret(secret_name=webhook_urls)
|
||||
if not isinstance(_env_value, str):
|
||||
raise ValueError(
|
||||
f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}"
|
||||
)
|
||||
_webhook_value_str = _env_value
|
||||
else:
|
||||
_webhook_value_str = webhook_urls
|
||||
|
||||
alert_to_webhook_url[alert_type] = _webhook_value_str
|
||||
|
||||
return alert_to_webhook_url
|
||||
|
||||
|
||||
async def _add_langfuse_trace_id_to_alert(
|
||||
request_data: Optional[dict] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Returns langfuse trace url
|
||||
|
||||
- check:
|
||||
-> existing_trace_id
|
||||
-> trace_id
|
||||
-> litellm_call_id
|
||||
"""
|
||||
# do nothing for now
|
||||
if (
|
||||
request_data is not None
|
||||
and request_data.get("litellm_logging_obj", None) is not None
|
||||
):
|
||||
trace_id: Optional[str] = None
|
||||
litellm_logging_obj: Logging = request_data["litellm_logging_obj"]
|
||||
|
||||
for _ in range(3):
|
||||
trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
|
||||
if trace_id is not None:
|
||||
break
|
||||
await asyncio.sleep(3) # wait 3s before retrying for trace id
|
||||
|
||||
_langfuse_object = litellm_logging_obj._get_callback_object(
|
||||
service_name="langfuse"
|
||||
)
|
||||
if _langfuse_object is not None:
|
||||
base_url = _langfuse_object.Langfuse.base_url
|
||||
return f"{base_url}/trace/{trace_id}"
|
||||
return None
|
||||
@@ -0,0 +1 @@
|
||||
from . import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user