structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
# Caching on LiteLLM
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
The following caching mechanisms are supported:
1. **RedisCache**
2. **RedisSemanticCache**
3. **QdrantSemanticCache**
4. **InMemoryCache**
5. **DiskCache**
6. **S3Cache**
7. **DualCache** (updates both Redis and an in-memory cache simultaneously)
## Folder Structure
```
litellm/caching/
├── base_cache.py
├── caching.py
├── caching_handler.py
├── disk_cache.py
├── dual_cache.py
├── in_memory_cache.py
├── qdrant_semantic_cache.py
├── redis_cache.py
├── redis_semantic_cache.py
├── s3_cache.py
```
## Documentation
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)

View File

@@ -0,0 +1,9 @@
from .caching import Cache, LiteLLMCacheType
from .disk_cache import DiskCache
from .dual_cache import DualCache
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache

View File

@@ -0,0 +1,30 @@
from functools import lru_cache
from typing import Callable, Optional, TypeVar
T = TypeVar("T")
def lru_cache_wrapper(
maxsize: Optional[int] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""
Wrapper for lru_cache that caches success and exceptions
"""
def decorator(f: Callable[..., T]) -> Callable[..., T]:
@lru_cache(maxsize=maxsize)
def wrapper(*args, **kwargs):
try:
return ("success", f(*args, **kwargs))
except Exception as e:
return ("error", e)
def wrapped(*args, **kwargs):
result = wrapper(*args, **kwargs)
if result[0] == "error":
raise result[1]
return result[1]
return wrapped
return decorator

View File

@@ -0,0 +1,55 @@
"""
Base Cache implementation. All cache implementations should inherit from this class.
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Union
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class BaseCache(ABC):
def __init__(self, default_ttl: int = 60):
self.default_ttl = default_ttl
def get_ttl(self, **kwargs) -> Optional[int]:
kwargs_ttl: Optional[int] = kwargs.get("ttl")
if kwargs_ttl is not None:
try:
return int(kwargs_ttl)
except ValueError:
return self.default_ttl
return self.default_ttl
def set_cache(self, key, value, **kwargs):
raise NotImplementedError
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError
@abstractmethod
async def async_set_cache_pipeline(self, cache_list, **kwargs):
pass
def get_cache(self, key, **kwargs):
raise NotImplementedError
async def async_get_cache(self, key, **kwargs):
raise NotImplementedError
async def batch_cache_write(self, key, value, **kwargs):
raise NotImplementedError
async def disconnect(self):
raise NotImplementedError

View File

@@ -0,0 +1,798 @@
# +-----------------------------------------------+
# | |
# | Give Feedback / Get Help |
# | https://github.com/BerriAI/litellm/issues/new |
# | |
# +-----------------------------------------------+
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import *
from litellm.types.utils import all_litellm_params
from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache # noqa
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache
def print_verbose(print_statement):
try:
verbose_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except Exception:
pass
class CacheMode(str, Enum):
default_on = "default_on"
default_off = "default_off"
#### LiteLLM.Completion / Embedding Cache ####
class Cache:
def __init__(
self,
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
mode: Optional[
CacheMode
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
namespace: Optional[str] = None,
ttl: Optional[float] = None,
default_in_memory_ttl: Optional[float] = None,
default_in_redis_ttl: Optional[float] = None,
similarity_threshold: Optional[float] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
],
# s3 Bucket, boto3 configuration
s3_bucket_name: Optional[str] = None,
s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None,
s3_use_ssl: Optional[bool] = True,
s3_verify: Optional[Union[bool, str]] = None,
s3_endpoint_url: Optional[str] = None,
s3_aws_access_key_id: Optional[str] = None,
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_config: Optional[Any] = None,
s3_path: Optional[str] = None,
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
redis_semantic_cache_index_name: Optional[str] = None,
redis_flush_size: Optional[int] = None,
redis_startup_nodes: Optional[List] = None,
disk_cache_dir: Optional[str] = None,
qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
qdrant_collection_name: Optional[str] = None,
qdrant_quantization_config: Optional[str] = None,
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
**kwargs,
):
"""
Initializes the cache based on the given type.
Args:
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
# Redis Cache Args
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
ttl (float, optional): The ttl for the Redis cache
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
# Qdrant Cache Args
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
# Disk Cache Args
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
# S3 Cache Args
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
# Common Cache Args
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
**kwargs: Additional keyword arguments for redis.Redis() cache
Raises:
ValueError: If an invalid cache type is provided.
Returns:
None. Cache is set as a litellm param
"""
if type == LiteLLMCacheType.REDIS:
if redis_startup_nodes:
self.cache: BaseCache = RedisClusterCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
startup_nodes=redis_startup_nodes,
**kwargs,
)
else:
self.cache = RedisCache(
host=host,
port=port,
password=password,
redis_flush_size=redis_flush_size,
**kwargs,
)
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
self.cache = RedisSemanticCache(
host=host,
port=port,
password=password,
similarity_threshold=similarity_threshold,
embedding_model=redis_semantic_cache_embedding_model,
index_name=redis_semantic_cache_index_name,
**kwargs,
)
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
self.cache = QdrantSemanticCache(
qdrant_api_base=qdrant_api_base,
qdrant_api_key=qdrant_api_key,
collection_name=qdrant_collection_name,
similarity_threshold=similarity_threshold,
quantization_config=qdrant_quantization_config,
embedding_model=qdrant_semantic_cache_embedding_model,
)
elif type == LiteLLMCacheType.LOCAL:
self.cache = InMemoryCache()
elif type == LiteLLMCacheType.S3:
self.cache = S3Cache(
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
s3_api_version=s3_api_version,
s3_use_ssl=s3_use_ssl,
s3_verify=s3_verify,
s3_endpoint_url=s3_endpoint_url,
s3_aws_access_key_id=s3_aws_access_key_id,
s3_aws_secret_access_key=s3_aws_secret_access_key,
s3_aws_session_token=s3_aws_session_token,
s3_config=s3_config,
s3_path=s3_path,
**kwargs,
)
elif type == LiteLLMCacheType.DISK:
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
self.type = type
self.namespace = namespace
self.redis_flush_size = redis_flush_size
self.ttl = ttl
self.mode: CacheMode = mode or CacheMode.default_on
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
self.ttl = default_in_memory_ttl
if (
self.type == LiteLLMCacheType.REDIS
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
) and default_in_redis_ttl is not None:
self.ttl = default_in_redis_ttl
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace
def get_cache_key(self, **kwargs) -> str:
"""
Get the cache key for the given arguments.
Args:
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
str: The cache key generated from the arguments, or None if no cache key could be generated.
"""
cache_key = ""
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
if preset_cache_key is not None:
verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
return preset_cache_key
combined_kwargs = ModelParamHelper._get_all_llm_api_params()
litellm_param_kwargs = all_litellm_params
for param in kwargs:
if param in combined_kwargs:
param_value: Optional[str] = self._get_param_value(param, kwargs)
if param_value is not None:
cache_key += f"{str(param)}: {str(param_value)}"
elif (
param not in litellm_param_kwargs
): # check if user passed in optional param - e.g. top_k
if (
litellm.enable_caching_on_provider_specific_optional_params is True
): # feature flagged for now
if kwargs[param] is None:
continue # ignore None params
param_value = kwargs[param]
cache_key += f"{str(param)}: {str(param_value)}"
verbose_logger.debug("\nCreated cache key: %s", cache_key)
hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
self._set_preset_cache_key_in_kwargs(
preset_cache_key=hashed_cache_key, **kwargs
)
return hashed_cache_key
def _get_param_value(
self,
param: str,
kwargs: dict,
) -> Optional[str]:
"""
Get the value for the given param from kwargs
"""
if param == "model":
return self._get_model_param_value(kwargs)
elif param == "file":
return self._get_file_param_value(kwargs)
return kwargs[param]
def _get_model_param_value(self, kwargs: dict) -> str:
"""
Handles getting the value for the 'model' param from kwargs
1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
3. Else use the `model` passed in kwargs
"""
metadata: Dict = kwargs.get("metadata", {}) or {}
litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
model_group: Optional[str] = metadata.get(
"model_group"
) or metadata_in_litellm_params.get("model_group")
caching_group = self._get_caching_group(metadata, model_group)
return caching_group or model_group or kwargs["model"]
def _get_caching_group(
self, metadata: dict, model_group: Optional[str]
) -> Optional[str]:
caching_groups: Optional[List] = metadata.get("caching_groups", [])
if caching_groups:
for group in caching_groups:
if model_group in group:
return str(group)
return None
def _get_file_param_value(self, kwargs: dict) -> str:
"""
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
"""
file = kwargs.get("file")
metadata = kwargs.get("metadata", {})
litellm_params = kwargs.get("litellm_params", {})
return (
metadata.get("file_checksum")
or getattr(file, "name", None)
or metadata.get("file_name")
or litellm_params.get("file_name")
)
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
"""
Get the preset cache key from kwargs["litellm_params"]
We use _get_preset_cache_keys for two reasons
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
2. avoid doing duplicate / repeated work
"""
if kwargs:
if "litellm_params" in kwargs:
return kwargs["litellm_params"].get("preset_cache_key", None)
return None
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
"""
Set the calculated cache key in kwargs
This is used to avoid doing duplicate / repeated work
Placed in kwargs["litellm_params"]
"""
if kwargs:
if "litellm_params" in kwargs:
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
@staticmethod
def _get_hashed_cache_key(cache_key: str) -> str:
"""
Get the hashed cache key for the given cache key.
Use hashlib to create a sha256 hash of the cache key
Args:
cache_key (str): The cache key to hash.
Returns:
str: The hashed cache key.
"""
hash_object = hashlib.sha256(cache_key.encode())
# Hexadecimal representation of the hash
hash_hex = hash_object.hexdigest()
verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
return hash_hex
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
"""
If a redis namespace is provided, add it to the cache key
Args:
hash_hex (str): The hashed cache key.
**kwargs: Additional keyword arguments.
Returns:
str: The final hashed cache key with the redis namespace.
"""
dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
namespace = (
dynamic_cache_control.get("namespace")
or kwargs.get("metadata", {}).get("redis_namespace")
or self.namespace
)
if namespace:
hash_hex = f"{namespace}:{hash_hex}"
verbose_logger.debug("Final hashed key: %s", hash_hex)
return hash_hex
def generate_streaming_content(self, content):
chunk_size = 5 # Adjust the chunk size as needed
for i in range(0, len(content), chunk_size):
yield {
"choices": [
{
"delta": {
"role": "assistant",
"content": content[i : i + chunk_size],
}
}
]
}
time.sleep(CACHED_STREAMING_CHUNK_DELAY)
def _get_cache_logic(
self,
cached_result: Optional[Any],
max_age: Optional[float],
):
"""
Common get cache logic across sync + async implementations
"""
# Check if a timestamp was stored with the cached response
if (
cached_result is not None
and isinstance(cached_result, dict)
and "timestamp" in cached_result
):
timestamp = cached_result["timestamp"]
current_time = time.time()
# Calculate age of the cached response
response_age = current_time - timestamp
# Check if the cached response is older than the max-age
if max_age is not None and response_age > max_age:
return None # Cached response is too old
# If the response is fresh, or there's no max-age requirement, return the cached response
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_result.get("response")
try:
if isinstance(cached_response, dict):
pass
else:
cached_response = json.loads(
cached_response # type: ignore
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response) # type: ignore
return cached_response
return cached_result
def get_cache(self, **kwargs):
"""
Retrieves the cached result for the given arguments.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
The cached result if it exists, otherwise None.
"""
try: # never block execution
if self.should_use_cache(**kwargs) is not True:
return
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
max_age = (
cache_control_args.get("s-maxage")
or cache_control_args.get("s-max-age")
or float("inf")
)
cached_result = self.cache.get_cache(cache_key, messages=messages)
cached_result = self.cache.get_cache(cache_key, messages=messages)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
async def async_get_cache(self, **kwargs):
"""
Async get cache implementation.
Used for embedding calls in async wrapper
"""
try: # never block execution
if self.should_use_cache(**kwargs) is not True:
return
kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
def _add_cache_logic(self, result, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
try:
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
if isinstance(result, BaseModel):
result = result.model_dump_json()
## DEFAULT TTL ##
if self.ttl is not None:
kwargs["ttl"] = self.ttl
## Get Cache-Controls ##
_cache_kwargs = kwargs.get("cache", None)
if isinstance(_cache_kwargs, dict):
for k, v in _cache_kwargs.items():
if k == "ttl":
kwargs["ttl"] = v
cached_data = {"timestamp": time.time(), "response": result}
return cache_key, cached_data, kwargs
else:
raise Exception("cache key is None")
except Exception as e:
raise e
def add_cache(self, result, **kwargs):
"""
Adds a result to the cache.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
None
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache(self, result, **kwargs):
"""
Async implementation of add_cache
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
if self.type == "redis" and self.redis_flush_size is not None:
# high traffic - fill in results in memory and then flush
await self.batch_cache_write(result, **kwargs)
else:
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, **kwargs
)
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
async def async_add_cache_pipeline(self, result, **kwargs):
"""
Async implementation of add_cache for Embedding calls
Does a bulk write, to prevent using too many clients
"""
try:
if self.should_use_cache(**kwargs) is not True:
return
# set default ttl if not set
if self.ttl is not None:
kwargs["ttl"] = self.ttl
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx]
cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_response,
**kwargs,
)
cache_list.append((cache_key, cached_data))
await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# if async_set_cache_pipeline:
# await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# else:
# tasks = []
# for val in cache_list:
# tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
# await asyncio.gather(*tasks)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
def should_use_cache(self, **kwargs):
"""
Returns true if we should use the cache for LLM API calls
If cache is default_on then this is True
If cache is default_off then this is only true when user has opted in to use cache
"""
if self.mode == CacheMode.default_on:
return True
# when mode == default_off -> Cache is opt in only
_cache = kwargs.get("cache", None)
verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
if _cache and isinstance(_cache, dict):
if _cache.get("use-cache", False) is True:
return True
return False
async def batch_cache_write(self, result, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):
cache_ping = getattr(self.cache, "ping")
if cache_ping:
return await cache_ping()
return None
async def delete_cache_keys(self, keys):
cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
if cache_delete_cache_keys:
return await cache_delete_cache_keys(keys)
return None
async def disconnect(self):
if hasattr(self.cache, "disconnect"):
await self.cache.disconnect()
def _supports_async(self) -> bool:
"""
Internal method to check if the cache type supports async get/set operations
Only S3 Cache Does NOT support async operations
"""
if self.type and self.type == LiteLLMCacheType.S3:
return False
return True
def enable_cache(
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
],
**kwargs,
):
"""
Enable cache with the specified configuration.
Args:
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
host (Optional[str]): The host address of the cache server. Defaults to None.
port (Optional[str]): The port number of the cache server. Defaults to None.
password (Optional[str]): The password for the cache server. Defaults to None.
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
**kwargs: Additional keyword arguments.
Returns:
None
Raises:
None
"""
print_verbose("LiteLLM: Enabling Cache")
if "cache" not in litellm.input_callback:
litellm.input_callback.append("cache")
if "cache" not in litellm.success_callback:
litellm.logging_callback_manager.add_litellm_success_callback("cache")
if "cache" not in litellm._async_success_callback:
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
if litellm.cache is None:
litellm.cache = Cache(
type=type,
host=host,
port=port,
password=password,
supported_call_types=supported_call_types,
**kwargs,
)
print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
def update_cache(
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
"completion",
"acompletion",
"embedding",
"aembedding",
"atranscription",
"transcription",
"atext_completion",
"text_completion",
"arerank",
"rerank",
],
**kwargs,
):
"""
Update the cache for LiteLLM.
Args:
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
host (Optional[str]): The host of the cache. Defaults to None.
port (Optional[str]): The port of the cache. Defaults to None.
password (Optional[str]): The password for the cache. Defaults to None.
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
**kwargs: Additional keyword arguments for the cache.
Returns:
None
"""
print_verbose("LiteLLM: Updating Cache")
litellm.cache = Cache(
type=type,
host=host,
port=port,
password=password,
supported_call_types=supported_call_types,
**kwargs,
)
print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
def disable_cache():
"""
Disable the cache used by LiteLLM.
This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
Parameters:
None
Returns:
None
"""
from contextlib import suppress
print_verbose("LiteLLM: Disabling Cache")
with suppress(ValueError):
litellm.input_callback.remove("cache")
litellm.success_callback.remove("cache")
litellm._async_success_callback.remove("cache")
litellm.cache = None
print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")

View File

@@ -0,0 +1,906 @@
"""
This contains LLMCachingHandler
This exposes two methods:
- async_get_cache
- async_set_cache
This file is a wrapper around caching.py
This class is used to handle caching logic specific for LLM API requests (completion / embedding / text_completion / transcription etc)
It utilizes the (RedisCache, s3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has setup
In each method it will call the appropriate method from caching.py
"""
import asyncio
import datetime
import inspect
import threading
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Union,
)
from pydantic import BaseModel
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import S3Cache
from litellm.litellm_core_utils.logging_utils import (
_assemble_complete_response_from_streaming_chunks,
)
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
CallTypes,
Embedding,
EmbeddingResponse,
ModelResponse,
TextCompletionResponse,
TranscriptionResponse,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import CustomStreamWrapper
else:
LiteLLMLoggingObj = Any
CustomStreamWrapper = Any
class CachingHandlerResponse(BaseModel):
"""
This is the response object for the caching handler. We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses
For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others
"""
cached_result: Optional[Any] = None
final_embedding_cached_response: Optional[EmbeddingResponse] = None
embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
class LLMCachingHandler:
def __init__(
self,
original_function: Callable,
request_kwargs: Dict[str, Any],
start_time: datetime.datetime,
):
self.async_streaming_chunks: List[ModelResponse] = []
self.sync_streaming_chunks: List[ModelResponse] = []
self.request_kwargs = request_kwargs
self.original_function = original_function
self.start_time = start_time
pass
async def _async_get_cache(
self,
model: str,
original_function: Callable,
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
call_type: str,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
) -> CachingHandlerResponse:
"""
Internal method to get from the cache.
Handles different call types (embeddings, chat/completions, text_completion, transcription)
and accordingly returns the cached response
Args:
model: str:
original_function: Callable:
logging_obj: LiteLLMLoggingObj:
start_time: datetime.datetime:
call_type: str:
kwargs: Dict[str, Any]:
args: Optional[Tuple[Any, ...]] = None:
Returns:
CachingHandlerResponse:
Raises:
None
"""
from litellm.utils import CustomStreamWrapper
args = args or ()
final_embedding_cached_response: Optional[EmbeddingResponse] = None
embedding_all_elements_cache_hit: bool = False
cached_result: Optional[Any] = None
if (
(kwargs.get("caching", None) is None and litellm.cache is not None)
or kwargs.get("caching", False) is True
) and (
kwargs.get("cache", {}).get("no-cache", False) is not True
): # allow users to control returning cached responses from the completion function
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
verbose_logger.debug("Checking Cache")
cached_result = await self._retrieve_from_cache(
call_type=call_type,
kwargs=kwargs,
args=args,
)
if cached_result is not None and not isinstance(cached_result, list):
verbose_logger.debug("Cache Hit!")
cache_hit = True
end_time = datetime.datetime.now()
model, _, _, _ = litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
self._update_litellm_logging_obj_environment(
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
cached_result=cached_result,
is_async=True,
)
call_type = original_function.__name__
cached_result = self._convert_cached_result_to_model_response(
cached_result=cached_result,
call_type=call_type,
kwargs=kwargs,
logging_obj=logging_obj,
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
args=args,
)
if kwargs.get("stream", False) is False:
# LOG SUCCESS
self._async_log_cache_hit_on_callbacks(
logging_obj=logging_obj,
cached_result=cached_result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
)
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
**kwargs
)
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return CachingHandlerResponse(cached_result=cached_result)
elif (
call_type == CallTypes.aembedding.value
and cached_result is not None
and isinstance(cached_result, list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
(
final_embedding_cached_response,
embedding_all_elements_cache_hit,
) = self._process_async_embedding_cached_response(
final_embedding_cached_response=final_embedding_cached_response,
cached_result=cached_result,
kwargs=kwargs,
logging_obj=logging_obj,
start_time=start_time,
model=model,
)
return CachingHandlerResponse(
final_embedding_cached_response=final_embedding_cached_response,
embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
)
verbose_logger.debug(f"CACHE RESULT: {cached_result}")
return CachingHandlerResponse(
cached_result=cached_result,
final_embedding_cached_response=final_embedding_cached_response,
)
def _sync_get_cache(
self,
model: str,
original_function: Callable,
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
call_type: str,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
) -> CachingHandlerResponse:
from litellm.utils import CustomStreamWrapper
args = args or ()
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
print_verbose("Checking Cache")
cached_result = litellm.cache.get_cache(**new_kwargs)
if cached_result is not None:
if "detail" in cached_result:
# implies an error occurred
pass
else:
call_type = original_function.__name__
cached_result = self._convert_cached_result_to_model_response(
cached_result=cached_result,
call_type=call_type,
kwargs=kwargs,
logging_obj=logging_obj,
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
args=args,
)
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model or "",
custom_llm_provider=kwargs.get("custom_llm_provider", None),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
self._update_litellm_logging_obj_environment(
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
cached_result=cached_result,
is_async=False,
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
**kwargs
)
if (
isinstance(cached_result, BaseModel)
or isinstance(cached_result, CustomStreamWrapper)
) and hasattr(cached_result, "_hidden_params"):
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
return CachingHandlerResponse(cached_result=cached_result)
return CachingHandlerResponse(cached_result=cached_result)
def _process_async_embedding_cached_response(
self,
final_embedding_cached_response: Optional[EmbeddingResponse],
cached_result: List[Optional[Dict[str, Any]]],
kwargs: Dict[str, Any],
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
model: str,
) -> Tuple[Optional[EmbeddingResponse], bool]:
"""
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others
This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
Args:
final_embedding_cached_response: Optional[EmbeddingResponse]:
cached_result: List[Optional[Dict[str, Any]]]:
kwargs: Dict[str, Any]:
logging_obj: LiteLLMLoggingObj:
start_time: datetime.datetime:
model: str:
Returns:
Tuple[Optional[EmbeddingResponse], bool]:
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
"""
embedding_all_elements_cache_hit: bool = False
remaining_list = []
non_null_list = []
for idx, cr in enumerate(cached_result):
if cr is None:
remaining_list.append(kwargs["input"][idx])
else:
non_null_list.append((idx, cr))
original_kwargs_input = kwargs["input"]
kwargs["input"] = remaining_list
if len(non_null_list) > 0:
print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
final_embedding_cached_response = EmbeddingResponse(
model=kwargs.get("model"),
data=[None] * len(original_kwargs_input),
)
final_embedding_cached_response._hidden_params["cache_hit"] = True
for val in non_null_list:
idx, cr = val # (idx, cr) tuple
if cr is not None:
final_embedding_cached_response.data[idx] = Embedding(
embedding=cr["embedding"],
index=idx,
object="embedding",
)
if len(remaining_list) == 0:
# LOG SUCCESS
cache_hit = True
embedding_all_elements_cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
self._update_litellm_logging_obj_environment(
logging_obj=logging_obj,
model=model,
kwargs=kwargs,
cached_result=final_embedding_cached_response,
is_async=True,
is_embedding=True,
)
self._async_log_cache_hit_on_callbacks(
logging_obj=logging_obj,
cached_result=final_embedding_cached_response,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
)
return final_embedding_cached_response, embedding_all_elements_cache_hit
return final_embedding_cached_response, embedding_all_elements_cache_hit
def _combine_cached_embedding_response_with_api_result(
self,
_caching_handler_response: CachingHandlerResponse,
embedding_response: EmbeddingResponse,
start_time: datetime.datetime,
end_time: datetime.datetime,
) -> EmbeddingResponse:
"""
Combines the cached embedding response with the API EmbeddingResponse
For caching there can be a cache hit for some of the inputs in the list and a cache miss for others
This function combines the cached embedding response with the API EmbeddingResponse
Args:
caching_handler_response: CachingHandlerResponse:
embedding_response: EmbeddingResponse:
Returns:
EmbeddingResponse:
"""
if _caching_handler_response.final_embedding_cached_response is None:
return embedding_response
idx = 0
final_data_list = []
for item in _caching_handler_response.final_embedding_cached_response.data:
if item is None and embedding_response.data is not None:
final_data_list.append(embedding_response.data[idx])
idx += 1
else:
final_data_list.append(item)
_caching_handler_response.final_embedding_cached_response.data = final_data_list
_caching_handler_response.final_embedding_cached_response._hidden_params[
"cache_hit"
] = True
_caching_handler_response.final_embedding_cached_response._response_ms = (
end_time - start_time
).total_seconds() * 1000
return _caching_handler_response.final_embedding_cached_response
def _async_log_cache_hit_on_callbacks(
self,
logging_obj: LiteLLMLoggingObj,
cached_result: Any,
start_time: datetime.datetime,
end_time: datetime.datetime,
cache_hit: bool,
):
"""
Helper function to log the success of a cached result on callbacks
Args:
logging_obj (LiteLLMLoggingObj): The logging object.
cached_result: The cached result.
start_time (datetime): The start time of the operation.
end_time (datetime): The end time of the operation.
cache_hit (bool): Whether it was a cache hit.
"""
asyncio.create_task(
logging_obj.async_success_handler(
cached_result, start_time, end_time, cache_hit
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(cached_result, start_time, end_time, cache_hit),
).start()
async def _retrieve_from_cache(
self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
) -> Optional[Any]:
"""
Internal method to
- get cache key
- check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
- async get cache value
- return the cached value
Args:
call_type: str:
kwargs: Dict[str, Any]:
args: Optional[Tuple[Any, ...]] = None:
Returns:
Optional[Any]:
Raises:
None
"""
if litellm.cache is None:
return None
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
cached_result: Optional[Any] = None
if call_type == CallTypes.aembedding.value and isinstance(
new_kwargs["input"], list
):
tasks = []
for idx, i in enumerate(new_kwargs["input"]):
preset_cache_key = litellm.cache.get_cache_key(
**{**new_kwargs, "input": i}
)
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
cached_result = await asyncio.gather(*tasks)
## check if cached result is None ##
if cached_result is not None and isinstance(cached_result, list):
# set cached_result to None if all elements are None
if all(result is None for result in cached_result):
cached_result = None
else:
if litellm.cache._supports_async() is True:
cached_result = await litellm.cache.async_get_cache(**new_kwargs)
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
cached_result = litellm.cache.get_cache(**new_kwargs)
return cached_result
def _convert_cached_result_to_model_response(
self,
cached_result: Any,
call_type: str,
kwargs: Dict[str, Any],
logging_obj: LiteLLMLoggingObj,
model: str,
args: Tuple[Any, ...],
custom_llm_provider: Optional[str] = None,
) -> Optional[
Union[
ModelResponse,
TextCompletionResponse,
EmbeddingResponse,
RerankResponse,
TranscriptionResponse,
CustomStreamWrapper,
]
]:
"""
Internal method to process the cached result
Checks the call type and converts the cached result to the appropriate model response object
example if call type is text_completion -> returns TextCompletionResponse object
Args:
cached_result: Any:
call_type: str:
kwargs: Dict[str, Any]:
logging_obj: LiteLLMLoggingObj:
model: str:
custom_llm_provider: Optional[str] = None:
args: Optional[Tuple[Any, ...]] = None:
Returns:
Optional[Any]:
"""
from litellm.utils import convert_to_model_response_object
if (
call_type == CallTypes.acompletion.value
or call_type == CallTypes.completion.value
) and isinstance(cached_result, dict):
if kwargs.get("stream", False) is True:
cached_result = self._convert_cached_stream_response(
cached_result=cached_result,
call_type=call_type,
logging_obj=logging_obj,
model=model,
)
else:
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=ModelResponse(),
)
if (
call_type == CallTypes.atext_completion.value
or call_type == CallTypes.text_completion.value
) and isinstance(cached_result, dict):
if kwargs.get("stream", False) is True:
cached_result = self._convert_cached_stream_response(
cached_result=cached_result,
call_type=call_type,
logging_obj=logging_obj,
model=model,
)
else:
cached_result = TextCompletionResponse(**cached_result)
elif (
call_type == CallTypes.aembedding.value
or call_type == CallTypes.embedding.value
) and isinstance(cached_result, dict):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=EmbeddingResponse(),
response_type="embedding",
)
elif (
call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
) and isinstance(cached_result, dict):
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=None,
response_type="rerank",
)
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value
) and isinstance(cached_result, dict):
hidden_params = {
"model": "whisper-1",
"custom_llm_provider": custom_llm_provider,
"cache_hit": True,
}
cached_result = convert_to_model_response_object(
response_object=cached_result,
model_response_object=TranscriptionResponse(),
response_type="audio_transcription",
hidden_params=hidden_params,
)
if (
hasattr(cached_result, "_hidden_params")
and cached_result._hidden_params is not None
and isinstance(cached_result._hidden_params, dict)
):
cached_result._hidden_params["cache_hit"] = True
return cached_result
def _convert_cached_stream_response(
self,
cached_result: Any,
call_type: str,
logging_obj: LiteLLMLoggingObj,
model: str,
) -> CustomStreamWrapper:
from litellm.utils import (
CustomStreamWrapper,
convert_to_streaming_response,
convert_to_streaming_response_async,
)
_stream_cached_result: Union[AsyncGenerator, Generator]
if (
call_type == CallTypes.acompletion.value
or call_type == CallTypes.atext_completion.value
):
_stream_cached_result = convert_to_streaming_response_async(
response_object=cached_result,
)
else:
_stream_cached_result = convert_to_streaming_response(
response_object=cached_result,
)
return CustomStreamWrapper(
completion_stream=_stream_cached_result,
model=model,
custom_llm_provider="cached_response",
logging_obj=logging_obj,
)
async def async_set_cache(
self,
result: Any,
original_function: Callable,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
):
"""
Internal method to check the type of the result & cache used and adds the result to the cache accordingly
Args:
result: Any:
original_function: Callable:
kwargs: Dict[str, Any]:
args: Optional[Tuple[Any, ...]] = None:
Returns:
None
Raises:
None
"""
if litellm.cache is None:
return
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
original_function,
args,
)
)
# [OPTIONAL] ADD TO CACHE
if self._should_store_result_in_cache(
original_function=original_function, kwargs=new_kwargs
):
if (
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.EmbeddingResponse)
or isinstance(result, TranscriptionResponse)
or isinstance(result, RerankResponse)
):
if (
isinstance(result, EmbeddingResponse)
and isinstance(new_kwargs["input"], list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache
) # s3 doesn't support bulk writing. Exclude.
):
asyncio.create_task(
litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
)
elif isinstance(litellm.cache.cache, S3Cache):
threading.Thread(
target=litellm.cache.add_cache,
args=(result,),
kwargs=new_kwargs,
).start()
else:
asyncio.create_task(
litellm.cache.async_add_cache(
result.model_dump_json(), **new_kwargs
)
)
else:
asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
def sync_set_cache(
self,
result: Any,
kwargs: Dict[str, Any],
args: Optional[Tuple[Any, ...]] = None,
):
"""
Sync internal method to add the result to the cache
"""
new_kwargs = kwargs.copy()
new_kwargs.update(
convert_args_to_kwargs(
self.original_function,
args,
)
)
if litellm.cache is None:
return
if self._should_store_result_in_cache(
original_function=self.original_function, kwargs=new_kwargs
):
litellm.cache.add_cache(result, **new_kwargs)
return
def _should_store_result_in_cache(
self, original_function: Callable, kwargs: Dict[str, Any]
) -> bool:
"""
Helper function to determine if the result should be stored in the cache.
Returns:
bool: True if the result should be stored in the cache, False otherwise.
"""
return (
(litellm.cache is not None)
and litellm.cache.supported_call_types is not None
and (str(original_function.__name__) in litellm.cache.supported_call_types)
and (kwargs.get("cache", {}).get("no-store", False) is not True)
)
def _is_call_type_supported_by_cache(
self,
original_function: Callable,
) -> bool:
"""
Helper function to determine if the call type is supported by the cache.
call types are acompletion, aembedding, atext_completion, atranscription, arerank
Defined on `litellm.types.utils.CallTypes`
Returns:
bool: True if the call type is supported by the cache, False otherwise.
"""
if (
litellm.cache is not None
and litellm.cache.supported_call_types is not None
and str(original_function.__name__) in litellm.cache.supported_call_types
):
return True
return False
async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
"""
Internal method to add the streaming response to the cache
- If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object
- Else append the chunk to self.async_streaming_chunks
"""
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=processed_chunk,
start_time=self.start_time,
end_time=datetime.datetime.now(),
request_kwargs=self.request_kwargs,
streaming_chunks=self.async_streaming_chunks,
is_async=True,
)
# if a complete_streaming_response is assembled, add it to the cache
if complete_streaming_response is not None:
await self.async_set_cache(
result=complete_streaming_response,
original_function=self.original_function,
kwargs=self.request_kwargs,
)
def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
"""
Sync internal method to add the streaming response to the cache
"""
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = _assemble_complete_response_from_streaming_chunks(
result=processed_chunk,
start_time=self.start_time,
end_time=datetime.datetime.now(),
request_kwargs=self.request_kwargs,
streaming_chunks=self.sync_streaming_chunks,
is_async=False,
)
# if a complete_streaming_response is assembled, add it to the cache
if complete_streaming_response is not None:
self.sync_set_cache(
result=complete_streaming_response,
kwargs=self.request_kwargs,
)
def _update_litellm_logging_obj_environment(
self,
logging_obj: LiteLLMLoggingObj,
model: str,
kwargs: Dict[str, Any],
cached_result: Any,
is_async: bool,
is_embedding: bool = False,
):
"""
Helper function to update the LiteLLMLoggingObj environment variables.
Args:
logging_obj (LiteLLMLoggingObj): The logging object to update.
model (str): The model being used.
kwargs (Dict[str, Any]): The keyword arguments from the original function call.
cached_result (Any): The cached result to log.
is_async (bool): Whether the call is asynchronous or not.
is_embedding (bool): Whether the call is for embeddings or not.
Returns:
None
"""
litellm_params = {
"logger_fn": kwargs.get("logger_fn", None),
"acompletion": is_async,
"api_base": kwargs.get("api_base", ""),
"metadata": kwargs.get("metadata", {}),
"model_info": kwargs.get("model_info", {}),
"proxy_server_request": kwargs.get("proxy_server_request", None),
"stream_response": kwargs.get("stream_response", {}),
}
if litellm.cache is not None:
litellm_params[
"preset_cache_key"
] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
else:
litellm_params["preset_cache_key"] = None
logging_obj.update_environment_variables(
model=model,
user=kwargs.get("user", None),
optional_params={},
litellm_params=litellm_params,
input=(
kwargs.get("messages", "")
if not is_embedding
else kwargs.get("input", "")
),
api_key=kwargs.get("api_key", None),
original_response=str(cached_result),
additional_args=None,
stream=kwargs.get("stream", False),
)
def convert_args_to_kwargs(
original_function: Callable,
args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
# Get the signature of the original function
signature = inspect.signature(original_function)
# Get parameter names in the order they appear in the original function
param_names = list(signature.parameters.keys())
# Create a mapping of positional arguments to parameter names
args_to_kwargs = {}
if args:
for index, arg in enumerate(args):
if index < len(param_names):
param_name = param_names[index]
args_to_kwargs[param_name] = arg
return args_to_kwargs

View File

@@ -0,0 +1,88 @@
import json
from typing import TYPE_CHECKING, Any, Optional, Union
from .base_cache import BaseCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class DiskCache(BaseCache):
def __init__(self, disk_cache_dir: Optional[str] = None):
import diskcache as dc
# if users don't provider one, use the default litellm cache
if disk_cache_dir is None:
self.disk_cache = dc.Cache(".litellm_cache")
else:
self.disk_cache = dc.Cache(disk_cache_dir)
def set_cache(self, key, value, **kwargs):
if "ttl" in kwargs:
self.disk_cache.set(key, value, expire=kwargs["ttl"])
else:
self.disk_cache.set(key, value)
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, **kwargs):
for cache_key, cache_value in cache_list:
if "ttl" in kwargs:
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
else:
self.set_cache(key=cache_key, value=cache_value)
def get_cache(self, key, **kwargs):
original_cached_response = self.disk_cache.get(key)
if original_cached_response:
try:
cached_response = json.loads(original_cached_response) # type: ignore
except Exception:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value # type: ignore
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: int, **kwargs) -> int:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value # type: ignore
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.disk_cache.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.disk_cache.pop(key)

View File

@@ -0,0 +1,434 @@
"""
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
Has 4 primary methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import asyncio
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, List, Optional, Union
import litellm
from litellm._logging import print_verbose, verbose_logger
from .base_cache import BaseCache
from .in_memory_cache import InMemoryCache
from .redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
from collections import OrderedDict
class LimitedSizeOrderedDict(OrderedDict):
def __init__(self, *args, max_size=100, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = max_size
def __setitem__(self, key, value):
# If inserting a new key exceeds max size, remove the oldest item
if len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)
class DualCache(BaseCache):
"""
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
When data is updated or inserted, it is written to both the in-memory cache + Redis.
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
"""
def __init__(
self,
in_memory_cache: Optional[InMemoryCache] = None,
redis_cache: Optional[RedisCache] = None,
default_in_memory_ttl: Optional[float] = None,
default_redis_ttl: Optional[float] = None,
default_redis_batch_cache_expiry: Optional[float] = None,
default_max_redis_batch_cache_size: int = 100,
) -> None:
super().__init__()
# If in_memory_cache is not provided, use the default InMemoryCache
self.in_memory_cache = in_memory_cache or InMemoryCache()
# If redis_cache is not provided, use the default RedisCache
self.redis_cache = redis_cache
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
max_size=default_max_redis_batch_cache_size
)
self.redis_batch_cache_expiry = (
default_redis_batch_cache_expiry
or litellm.default_redis_batch_cache_expiry
or 10
)
self.default_in_memory_ttl = (
default_in_memory_ttl or litellm.default_in_memory_ttl
)
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
def update_cache_ttl(
self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
):
if default_in_memory_ttl is not None:
self.default_in_memory_ttl = default_in_memory_ttl
if default_redis_ttl is not None:
self.default_redis_ttl = default_redis_ttl
def set_cache(self, key, value, local_only: bool = False, **kwargs):
# Update both Redis and in-memory cache
try:
if self.in_memory_cache is not None:
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
kwargs["ttl"] = self.default_in_memory_ttl
self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e:
print_verbose(e)
def increment_cache(
self, key, value: int, local_only: bool = False, **kwargs
) -> int:
"""
Key - the key in cache
Value - int - the value you want to increment by
Returns - int - the incremented value
"""
try:
result: int = value
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
except Exception as e:
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
raise e
def get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
result = None
if self.in_memory_cache is not None:
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
if in_memory_result is not None:
result = in_memory_result
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = self.redis_cache.get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
result = redis_result
print_verbose(f"get cache: cache result: {result}")
return result
except Exception:
verbose_logger.error(traceback.format_exc())
def batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
received_args = locals()
received_args.pop("self")
def run_in_new_loop():
"""Run the coroutine in a new event loop within this thread."""
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
return new_loop.run_until_complete(
self.async_batch_get_cache(**received_args)
)
finally:
new_loop.close()
asyncio.set_event_loop(None)
try:
# First, try to get the current event loop
_ = asyncio.get_running_loop()
# If we're already in an event loop, run in a separate thread
# to avoid nested event loop issues
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
return future.result()
except RuntimeError:
# No running event loop, we can safely run in this thread
return run_in_new_loop()
async def async_get_cache(
self,
key,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
# Try to fetch from in-memory cache first
try:
print_verbose(
f"async get cache: cache key: {key}; local_only: {local_only}"
)
result = None
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_get_cache(
key, **kwargs
)
print_verbose(f"in_memory_result: {in_memory_result}")
if in_memory_result is not None:
result = in_memory_result
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_get_cache(
key, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
await self.in_memory_cache.async_set_cache(
key, redis_result, **kwargs
)
result = redis_result
print_verbose(f"get cache: cache result: {result}")
return result
except Exception:
verbose_logger.error(traceback.format_exc())
def get_redis_batch_keys(
self,
current_time: float,
keys: List[str],
result: List[Any],
) -> List[str]:
sublist_keys = []
for key, value in zip(keys, result):
if value is None:
if (
key not in self.last_redis_batch_access_time
or current_time - self.last_redis_batch_access_time[key]
>= self.redis_batch_cache_expiry
):
sublist_keys.append(key)
return sublist_keys
async def async_batch_get_cache(
self,
keys: list,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
):
try:
result = [None for _ in range(len(keys))]
if self.in_memory_cache is not None:
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
keys, **kwargs
)
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only is False:
"""
- for the none values in the result
- check the redis cache
"""
current_time = time.time()
sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
# Only hit Redis if the last access time was more than 5 seconds ago
if len(sublist_keys) > 0:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_batch_get_cache(
sublist_keys, parent_otel_span=parent_otel_span
)
if redis_result is not None:
# Update in-memory cache with the value from Redis
for key, value in redis_result.items():
if value is not None:
await self.in_memory_cache.async_set_cache(
key, redis_result[key], **kwargs
)
# Update the last access time for each key fetched from Redis
self.last_redis_batch_access_time[key] = current_time
for key, value in redis_result.items():
index = keys.index(key)
result[index] = value
return result
except Exception:
verbose_logger.error(traceback.format_exc())
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
print_verbose(
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache(key, value, **kwargs)
except Exception as e:
verbose_logger.exception(
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
)
# async_batch_set_cache
async def async_set_cache_pipeline(
self, cache_list: list, local_only: bool = False, **kwargs
):
"""
Batch write values to the cache
"""
print_verbose(
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
)
try:
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache_pipeline(
cache_list=cache_list, **kwargs
)
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache_pipeline(
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
)
except Exception as e:
verbose_logger.exception(
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
)
async def async_increment_cache(
self,
key,
value: float,
parent_otel_span: Optional[Span] = None,
local_only: bool = False,
**kwargs,
) -> float:
"""
Key - the key in cache
Value - float - the value you want to increment by
Returns - float - the incremented value
"""
try:
result: float = value
if self.in_memory_cache is not None:
result = await self.in_memory_cache.async_increment(
key, value, **kwargs
)
if self.redis_cache is not None and local_only is False:
result = await self.redis_cache.async_increment(
key,
value,
parent_otel_span=parent_otel_span,
ttl=kwargs.get("ttl", None),
)
return result
except Exception as e:
raise e # don't log if exception is raised
async def async_set_cache_sadd(
self, key, value: List, local_only: bool = False, **kwargs
) -> None:
"""
Add value to a set
Key - the key in cache
Value - str - the value you want to add to the set
Returns - None
"""
try:
if self.in_memory_cache is not None:
_ = await self.in_memory_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None)
)
if self.redis_cache is not None and local_only is False:
_ = await self.redis_cache.async_set_cache_sadd(
key, value, ttl=kwargs.get("ttl", None)
)
return None
except Exception as e:
raise e # don't log, if exception is raised
def flush_cache(self):
if self.in_memory_cache is not None:
self.in_memory_cache.flush_cache()
if self.redis_cache is not None:
self.redis_cache.flush_cache()
def delete_cache(self, key):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
self.redis_cache.delete_cache(key)
async def async_delete_cache(self, key: str):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
await self.redis_cache.async_delete_cache(key)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache or redis
"""
ttl = await self.in_memory_cache.async_get_ttl(key)
if ttl is None and self.redis_cache is not None:
ttl = await self.redis_cache.async_get_ttl(key)
return ttl

View File

@@ -0,0 +1,204 @@
"""
In-Memory Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import json
import sys
import time
from typing import Any, List, Optional
from pydantic import BaseModel
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
from .base_cache import BaseCache
class InMemoryCache(BaseCache):
def __init__(
self,
max_size_in_memory: Optional[int] = 200,
default_ttl: Optional[
int
] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
):
"""
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default
"""
self.max_size_in_memory = (
max_size_in_memory or 200
) # set an upper bound of 200 items in-memory
self.default_ttl = default_ttl or 600
self.max_size_per_item = (
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
) # 1MB = 1024KB
# in-memory cache
self.cache_dict: dict = {}
self.ttl_dict: dict = {}
def check_value_size(self, value: Any):
"""
Check if value size exceeds max_size_per_item (1MB)
Returns True if value size is acceptable, False otherwise
"""
try:
# Fast path for common primitive types that are typically small
if (
isinstance(value, (bool, int, float, str))
and len(str(value))
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
): # Conservative estimate
return True
# Direct size check for bytes objects
if isinstance(value, bytes):
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
# Handle special types without full conversion when possible
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
size = value.__sizeof__() / 1024
return size <= self.max_size_per_item
# Fallback for complex types
if isinstance(value, BaseModel) and hasattr(
value, "model_dump"
): # Pydantic v2
value = value.model_dump()
elif hasattr(value, "isoformat"): # datetime objects
return True # datetime strings are always small
# Only convert to JSON if absolutely necessary
if not isinstance(value, (str, bytes)):
value = json.dumps(value, default=str)
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
except Exception:
return False
def evict_cache(self):
"""
Eviction policy:
- check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict
This guarantees the following:
- 1. When item ttl not set: At minimumm each item will remain in memory for 5 minutes
- 2. When ttl is set: the item will remain in memory for at least that amount of time
- 3. the size of in-memory cache is bounded
"""
for key in list(self.ttl_dict.keys()):
if time.time() > self.ttl_dict[key]:
self.cache_dict.pop(key, None)
self.ttl_dict.pop(key, None)
# de-reference the removed item
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
# This can occur when an object is referenced by another object, but the reference is never removed.
def set_cache(self, key, value, **kwargs):
if len(self.cache_dict) >= self.max_size_in_memory:
# only evict when cache is full
self.evict_cache()
if not self.check_value_size(value):
return
self.cache_dict[key] = value
if "ttl" in kwargs and kwargs["ttl"] is not None:
self.ttl_dict[key] = time.time() + kwargs["ttl"]
else:
self.ttl_dict[key] = time.time() + self.default_ttl
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
for cache_key, cache_value in cache_list:
if ttl is not None:
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
else:
self.set_cache(key=cache_key, value=cache_value)
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
"""
Add value to set
"""
# get the value
init_value = self.get_cache(key=key) or set()
for val in value:
init_value.add(val)
self.set_cache(key, init_value, ttl=ttl)
return value
def get_cache(self, key, **kwargs):
if key in self.cache_dict:
if key in self.ttl_dict:
if time.time() > self.ttl_dict[key]:
self.cache_dict.pop(key, None)
return None
original_cached_response = self.cache_dict[key]
try:
cached_response = json.loads(original_cached_response)
except Exception:
cached_response = original_cached_response
return cached_response
return None
def batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
def increment_cache(self, key, value: int, **kwargs) -> int:
# get the value
init_value = self.get_cache(key=key) or 0
value = init_value + value
self.set_cache(key, value, **kwargs)
return value
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
async def async_batch_get_cache(self, keys: list, **kwargs):
return_val = []
for k in keys:
val = self.get_cache(key=k, **kwargs)
return_val.append(val)
return return_val
async def async_increment(self, key, value: float, **kwargs) -> float:
# get the value
init_value = await self.async_get_cache(key=key) or 0
value = init_value + value
await self.async_set_cache(key, value, **kwargs)
return value
def flush_cache(self):
self.cache_dict.clear()
self.ttl_dict.clear()
async def disconnect(self):
pass
def delete_cache(self, key):
self.cache_dict.pop(key, None)
self.ttl_dict.pop(key, None)
async def async_get_ttl(self, key: str) -> Optional[int]:
"""
Get the remaining TTL of a key in in-memory cache
"""
return self.ttl_dict.get(key, None)

View File

@@ -0,0 +1,39 @@
"""
Add the event loop to the cache key, to prevent event loop closed errors.
"""
import asyncio
from .in_memory_cache import InMemoryCache
class LLMClientCache(InMemoryCache):
def update_cache_key_with_event_loop(self, key):
"""
Add the event loop to the cache key, to prevent event loop closed errors.
If none, use the key as is.
"""
try:
event_loop = asyncio.get_event_loop()
stringified_event_loop = str(id(event_loop))
return f"{key}-{stringified_event_loop}"
except Exception: # handle no current event loop
return key
def set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().set_cache(key, value, **kwargs)
async def async_set_cache(self, key, value, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_set_cache(key, value, **kwargs)
def get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return super().get_cache(key, **kwargs)
async def async_get_cache(self, key, **kwargs):
key = self.update_cache_key_with_event_loop(key)
return await super().async_get_cache(key, **kwargs)

View File

@@ -0,0 +1,442 @@
"""
Qdrant Semantic Cache implementation
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import ast
import asyncio
import json
from typing import Any, cast
import litellm
from litellm._logging import print_verbose
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class QdrantSemanticCache(BaseCache):
def __init__( # noqa: PLR0915
self,
qdrant_api_base=None,
qdrant_api_key=None,
collection_name=None,
similarity_threshold=None,
quantization_config=None,
embedding_model="text-embedding-ada-002",
host_type=None,
):
import os
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.secret_managers.main import get_secret_str
if collection_name is None:
raise Exception("collection_name must be provided, passed None")
self.collection_name = collection_name
print_verbose(
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
)
if similarity_threshold is None:
raise Exception("similarity_threshold must be provided, passed None")
self.similarity_threshold = similarity_threshold
self.embedding_model = embedding_model
headers = {}
# check if defined as os.environ/ variable
if qdrant_api_base:
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
"os.environ/"
):
qdrant_api_base = get_secret_str(qdrant_api_base)
if qdrant_api_key:
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
"os.environ/"
):
qdrant_api_key = get_secret_str(qdrant_api_key)
qdrant_api_base = (
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
)
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
headers = {"Content-Type": "application/json"}
if qdrant_api_key:
headers["api-key"] = qdrant_api_key
if qdrant_api_base is None:
raise ValueError("Qdrant url must be provided")
self.qdrant_api_base = qdrant_api_base
self.qdrant_api_key = qdrant_api_key
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
self.headers = headers
self.sync_client = _get_httpx_client()
self.async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.Caching
)
if quantization_config is None:
print_verbose(
"Quantization config is not provided. Default binary quantization will be used."
)
collection_exists = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
headers=self.headers,
)
if collection_exists.status_code != 200:
raise ValueError(
f"Error from qdrant checking if /collections exist {collection_exists.text}"
)
if collection_exists.json()["result"]["exists"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(
f"Collection already exists.\nCollection details:{self.collection_info}"
)
else:
if quantization_config is None or quantization_config == "binary":
quantization_params = {
"binary": {
"always_ram": False,
}
}
elif quantization_config == "scalar":
quantization_params = {
"scalar": {
"type": "int8",
"quantile": QDRANT_SCALAR_QUANTILE,
"always_ram": False,
}
}
elif quantization_config == "product":
quantization_params = {
"product": {"compression": "x16", "always_ram": False}
}
else:
raise Exception(
"Quantization config must be one of 'scalar', 'binary' or 'product'"
)
new_collection_status = self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
json={
"vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
"quantization_config": quantization_params,
},
headers=self.headers,
)
if new_collection_status.json()["result"]:
collection_details = self.sync_client.get(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
headers=self.headers,
)
self.collection_info = collection_details.json()
print_verbose(
f"New collection created.\nCollection details:{self.collection_info}"
)
else:
raise Exception("Error while creating new collection")
def _get_cache_logic(self, cached_response: Any):
if cached_response is None:
return cached_response
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
def set_cache(self, key, value, **kwargs):
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
import uuid
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
value = str(value)
assert isinstance(value, str)
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
},
]
}
self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
return
def get_cache(self, key, **kwargs):
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# convert to embedding
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit": 1,
"with_payload": True,
}
search_response = self.sync_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
results = search_response.json()["result"]
if results is None:
return None
if isinstance(results, list):
if len(results) == 0:
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
pass
async def async_set_cache(self, key, value, **kwargs):
import uuid
from litellm.proxy.proxy_server import llm_model_list, llm_router
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
value = str(value)
assert isinstance(value, str)
data = {
"points": [
{
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
"text": prompt,
"response": value,
},
},
]
}
await self.async_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
)
return
async def async_get_cache(self, key, **kwargs):
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
from litellm.proxy.proxy_server import llm_model_list, llm_router
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
data = {
"vector": embedding,
"params": {
"quantization": {
"ignore": False,
"rescore": True,
"oversampling": 3.0,
}
},
"limit": 1,
"with_payload": True,
}
search_response = await self.async_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
headers=self.headers,
json=data,
)
results = search_response.json()["result"]
if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
if len(results) == 0:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
pass
async def _collection_info(self):
return self.collection_info
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,59 @@
"""
Redis Cluster Cache implementation
Key differences:
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union
from litellm.caching.redis_cache import RedisCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline
pipeline = Pipeline
async_redis_client = Redis
Span = Union[_Span, Any]
else:
pipeline = Any
async_redis_client = Any
Span = Any
class RedisClusterCache(RedisCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
def init_async_client(self):
from redis.asyncio import RedisCluster
from .._redis import get_redis_async_client
if self.redis_async_redis_cluster_client:
return self.redis_async_redis_cluster_client
_redis_client = get_redis_async_client(
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
)
if isinstance(_redis_client, RedisCluster):
self.redis_async_redis_cluster_client = _redis_client
return _redis_client
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Overrides `_run_redis_mget_operation` in redis_cache.py
"""
return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
Overrides `_async_run_redis_mget_operation` in redis_cache.py
"""
async_redis_cluster_client = self.init_async_client()
return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore

View File

@@ -0,0 +1,450 @@
"""
Redis Semantic Cache implementation for LiteLLM
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
This cache stores responses based on the semantic similarity of prompts rather than
exact matching, allowing for more flexible caching of LLM responses.
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
and their cached responses.
"""
import ast
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Tuple, cast
import litellm
from litellm._logging import print_verbose
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class RedisSemanticCache(BaseCache):
"""
Redis-backed semantic cache for LLM responses.
This cache uses vector similarity to find semantically similar prompts that have been
previously sent to the LLM, allowing for cache hits even when prompts are not identical
but carry similar meaning.
"""
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
def __init__(
self,
host: Optional[str] = None,
port: Optional[str] = None,
password: Optional[str] = None,
redis_url: Optional[str] = None,
similarity_threshold: Optional[float] = None,
embedding_model: str = "text-embedding-ada-002",
index_name: Optional[str] = None,
**kwargs,
):
"""
Initialize the Redis Semantic Cache.
Args:
host: Redis host address
port: Redis port
password: Redis password
redis_url: Full Redis URL (alternative to separate host/port/password)
similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
where 1.0 requires exact matches and 0.0 accepts any match
embedding_model: Model to use for generating embeddings
index_name: Name for the Redis index
ttl: Default time-to-live for cache entries in seconds
**kwargs: Additional arguments passed to the Redis client
Raises:
Exception: If similarity_threshold is not provided or required Redis
connection information is missing
"""
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer
if index_name is None:
index_name = self.DEFAULT_REDIS_INDEX_NAME
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
# Validate similarity threshold
if similarity_threshold is None:
raise ValueError("similarity_threshold must be provided, passed None")
# Store configuration
self.similarity_threshold = similarity_threshold
# Convert similarity threshold [0,1] to distance threshold [0,2]
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
self.distance_threshold = 1 - similarity_threshold
self.embedding_model = embedding_model
# Set up Redis connection
if redis_url is None:
try:
# Attempt to use provided parameters or fallback to environment variables
host = host or os.environ["REDIS_HOST"]
port = port or os.environ["REDIS_PORT"]
password = password or os.environ["REDIS_PASSWORD"]
except KeyError as e:
# Raise a more informative exception if any of the required keys are missing
missing_var = e.args[0]
raise ValueError(
f"Missing required Redis configuration: {missing_var}. "
f"Provide {missing_var} or redis_url."
) from e
redis_url = f"redis://:{password}@{host}:{port}"
print_verbose(f"Redis semantic-cache redis_url: {redis_url}")
# Initialize the Redis vectorizer and cache
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
self.llmcache = SemanticCache(
name=index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
overwrite=False,
)
def _get_ttl(self, **kwargs) -> Optional[int]:
"""
Get the TTL (time-to-live) value for cache entries.
Args:
**kwargs: Keyword arguments that may contain a custom TTL
Returns:
Optional[int]: The TTL value in seconds, or None if no TTL should be applied
"""
ttl = kwargs.get("ttl")
if ttl is not None:
ttl = int(ttl)
return ttl
def _get_embedding(self, prompt: str) -> List[float]:
"""
Generate an embedding vector for the given prompt using the configured embedding model.
Args:
prompt: The text to generate an embedding for
Returns:
List[float]: The embedding vector
"""
# Create an embedding from prompt
embedding_response = cast(
EmbeddingResponse,
litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
),
)
embedding = embedding_response["data"][0]["embedding"]
return embedding
def _get_cache_logic(self, cached_response: Any) -> Any:
"""
Process the cached response to prepare it for use.
Args:
cached_response: The raw cached response
Returns:
The processed cache response, or None if input was None
"""
if cached_response is None:
return cached_response
# Convert bytes to string if needed
if isinstance(cached_response, bytes):
cached_response = cached_response.decode("utf-8")
# Convert string representation to Python object
try:
cached_response = json.loads(cached_response)
except json.JSONDecodeError:
try:
cached_response = ast.literal_eval(cached_response)
except (ValueError, SyntaxError) as e:
print_verbose(f"Error parsing cached response: {str(e)}")
return None
return cached_response
def set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
"""
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
value_str: Optional[str] = None
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
self.llmcache.store(prompt, value_str, ttl=int(ttl))
else:
self.llmcache.store(prompt, value_str)
except Exception as e:
print_verbose(
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
)
def get_cache(self, key: str, **kwargs) -> Any:
"""
Retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic cache lookup")
return None
prompt = get_str_from_messages(messages)
# Check the cache for semantically similar prompts
results = self.llmcache.check(prompt=prompt)
# Return None if no similar prompts found
if not results:
return None
# Process the best matching result
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity score
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
similarity = 1 - vector_distance
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
print_verbose(
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
f"actual similarity: {similarity}, "
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
"""
Asynchronously generate an embedding for the given prompt.
Args:
prompt: The text to generate an embedding for
**kwargs: Additional arguments that may contain metadata
Returns:
List[float]: The embedding vector
"""
from litellm.proxy.proxy_server import llm_model_list, llm_router
# Route the embedding request through the proxy if appropriate
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
try:
if llm_router is not None and self.embedding_model in router_model_names:
# Use the router for embedding generation
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# Generate embedding directly
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
# Extract and return the embedding vector
return embedding_response["data"][0]["embedding"]
except Exception as e:
print_verbose(f"Error generating async embedding: {str(e)}")
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
"""
Asynchronously store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
"""
print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic caching")
return
prompt = get_str_from_messages(messages)
value_str = str(value)
# Generate embedding for the value (response) to cache
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
ttl=ttl,
)
else:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
)
except Exception as e:
print_verbose(f"Error in async_set_cache: {str(e)}")
async def async_get_cache(self, key: str, **kwargs) -> Any:
"""
Asynchronously retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
The cached response if a semantically similar prompt is found, else None
"""
print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")
try:
# Extract the prompt from messages
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic cache lookup")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
prompt = get_str_from_messages(messages)
# Generate embedding for the prompt
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Check the cache for semantically similar prompts
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
# handle results / cache hit
if not results:
kwargs.setdefault("metadata", {})[
"semantic-similarity"
] = 0.0 # TODO why here but not above??
return None
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity
# For cosine distance: 0 = most similar, 2 = least similar
# While similarity: 1 = most similar, 0 = least similar
similarity = 1 - vector_distance
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
print_verbose(
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
f"actual similarity: {similarity}, "
f"current prompt: {prompt}, "
f"cached prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error in async_get_cache: {str(e)}")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
async def _index_info(self) -> Dict[str, Any]:
"""
Get information about the Redis index.
Returns:
Dict[str, Any]: Information about the Redis index
"""
aindex = await self.llmcache._get_async_index()
return await aindex.info()
async def async_set_cache_pipeline(
self, cache_list: List[Tuple[str, Any]], **kwargs
) -> None:
"""
Asynchronously store multiple values in the semantic cache.
Args:
cache_list: List of (key, value) tuples to cache
**kwargs: Additional arguments
"""
try:
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
except Exception as e:
print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")

View File

@@ -0,0 +1,159 @@
"""
S3 Cache implementation
WARNING: DO NOT USE THIS IN PRODUCTION - This is not ASYNC
Has 4 methods:
- set_cache
- get_cache
- async_set_cache
- async_get_cache
"""
import ast
import asyncio
import json
from typing import Optional
from litellm._logging import print_verbose, verbose_logger
from .base_cache import BaseCache
class S3Cache(BaseCache):
def __init__(
self,
s3_bucket_name,
s3_region_name=None,
s3_api_version=None,
s3_use_ssl: Optional[bool] = True,
s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
s3_aws_secret_access_key=None,
s3_aws_session_token=None,
s3_config=None,
s3_path=None,
**kwargs,
):
import boto3
self.bucket_name = s3_bucket_name
self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
# Create an S3 client with custom endpoint URL
self.s3_client = boto3.client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
api_version=s3_api_version,
use_ssl=s3_use_ssl,
verify=s3_verify,
aws_access_key_id=s3_aws_access_key_id,
aws_secret_access_key=s3_aws_secret_access_key,
aws_session_token=s3_aws_session_token,
config=s3_config,
**kwargs,
)
def set_cache(self, key, value, **kwargs):
try:
print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
ttl = kwargs.get("ttl", None)
# Convert value to JSON before storing in S3
serialized_value = json.dumps(value)
key = self.key_prefix + key
if ttl is not None:
cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
import datetime
# Calculate expiration time
expiration_time = datetime.datetime.now() + ttl
# Upload the data to S3 with the calculated expiration time
self.s3_client.put_object(
Bucket=self.bucket_name,
Key=key,
Body=serialized_value,
Expires=expiration_time,
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{key}.json"',
)
else:
cache_control = "immutable, max-age=31536000, s-maxage=31536000"
# Upload the data to S3 without specifying Expires
self.s3_client.put_object(
Bucket=self.bucket_name,
Key=key,
Body=serialized_value,
CacheControl=cache_control,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{key}.json"',
)
except Exception as e:
# NON blocking - notify users S3 is throwing an exception
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
async def async_set_cache(self, key, value, **kwargs):
self.set_cache(key=key, value=value, **kwargs)
def get_cache(self, key, **kwargs):
import botocore
try:
key = self.key_prefix + key
print_verbose(f"Get S3 Cache: key: {key}")
# Download the data from S3
cached_response = self.s3_client.get_object(
Bucket=self.bucket_name, Key=key
)
if cached_response is not None:
# cached_response is in `b{} convert it to ModelResponse
cached_response = (
cached_response["Body"].read().decode("utf-8")
) # Convert bytes to string
try:
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
if not isinstance(cached_response, dict):
cached_response = dict(cached_response)
verbose_logger.debug(
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
)
return cached_response
except botocore.exceptions.ClientError as e: # type: ignore
if e.response["Error"]["Code"] == "NoSuchKey":
verbose_logger.debug(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
)
return None
except Exception as e:
# NON blocking - notify users S3 is throwing an exception
verbose_logger.error(
f"S3 Caching: get_cache() - Got exception from S3: {e}"
)
async def async_get_cache(self, key, **kwargs):
return self.get_cache(key=key, **kwargs)
def flush_cache(self):
pass
async def disconnect(self):
pass
async def async_set_cache_pipeline(self, cache_list, **kwargs):
tasks = []
for val in cache_list:
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)