structure saas with tools
This commit is contained in:
798  .venv/lib/python3.10/site-packages/litellm/caching/caching.py  Normal file
@@ -0,0 +1,798 @@
# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import *
from litellm.types.utils import all_litellm_params

from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache  # noqa
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache

def print_verbose(print_statement):
    try:
        verbose_logger.debug(print_statement)
        if litellm.set_verbose:
            print(print_statement)  # noqa
    except Exception:
        pass


class CacheMode(str, Enum):
    default_on = "default_on"
    default_off = "default_off"


#### LiteLLM.Completion / Embedding Cache ####
class Cache:
    def __init__(
        self,
        type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
        mode: Optional[
            CacheMode
        ] = CacheMode.default_on,  # when default_on, the cache is always on; when default_off, caching is opt-in per request
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        namespace: Optional[str] = None,
        ttl: Optional[float] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_in_redis_ttl: Optional[float] = None,
        similarity_threshold: Optional[float] = None,
        supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
            "atext_completion",
            "text_completion",
            "arerank",
            "rerank",
        ],
        # s3 Bucket, boto3 configuration
        s3_bucket_name: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        s3_api_version: Optional[str] = None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify: Optional[Union[bool, str]] = None,
        s3_endpoint_url: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_aws_session_token: Optional[str] = None,
        s3_config: Optional[Any] = None,
        s3_path: Optional[str] = None,
        redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
        redis_semantic_cache_index_name: Optional[str] = None,
        redis_flush_size: Optional[int] = None,
        redis_startup_nodes: Optional[List] = None,
        disk_cache_dir: Optional[str] = None,
        qdrant_api_base: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        qdrant_collection_name: Optional[str] = None,
        qdrant_quantization_config: Optional[str] = None,
        qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
        **kwargs,
    ):
        """
        Initializes the cache based on the given type.

        Args:
            type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".

            # Redis Cache Args
            host (str, optional): The host address for the Redis cache. Required if type is "redis".
            port (int, optional): The port number for the Redis cache. Required if type is "redis".
            password (str, optional): The password for the Redis cache. Required if type is "redis".
            namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
            ttl (float, optional): The ttl for the Redis cache
            redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
            redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.

            # Qdrant Cache Args
            qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
            qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
            qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
            similarity_threshold (float, optional): The similarity threshold for semantic caching. Required if type is "redis-semantic" or "qdrant-semantic".

            # Disk Cache Args
            disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.

            # S3 Cache Args
            s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
            s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
            s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
            s3_use_ssl (bool, optional): Whether to use SSL for the s3 cache. Defaults to True.
            s3_verify (bool, optional): Whether to verify SSL certificates for the s3 cache. Defaults to None.
            s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
            s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
            s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
            s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
            s3_config (dict, optional): The config for the s3 cache. Defaults to None.

            # Common Cache Args
            supported_call_types (list, optional): List of call types to cache for. Defaults to caching on for all call types.
            **kwargs: Additional keyword arguments for redis.Redis() cache

        Raises:
            ValueError: If an invalid cache type is provided.

        Returns:
            None. Cache is set as a litellm param
        """
        if type == LiteLLMCacheType.REDIS:
            if redis_startup_nodes:
                self.cache: BaseCache = RedisClusterCache(
                    host=host,
                    port=port,
                    password=password,
                    redis_flush_size=redis_flush_size,
                    startup_nodes=redis_startup_nodes,
                    **kwargs,
                )
            else:
                self.cache = RedisCache(
                    host=host,
                    port=port,
                    password=password,
                    redis_flush_size=redis_flush_size,
                    **kwargs,
                )
        elif type == LiteLLMCacheType.REDIS_SEMANTIC:
            self.cache = RedisSemanticCache(
                host=host,
                port=port,
                password=password,
                similarity_threshold=similarity_threshold,
                embedding_model=redis_semantic_cache_embedding_model,
                index_name=redis_semantic_cache_index_name,
                **kwargs,
            )
        elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
            self.cache = QdrantSemanticCache(
                qdrant_api_base=qdrant_api_base,
                qdrant_api_key=qdrant_api_key,
                collection_name=qdrant_collection_name,
                similarity_threshold=similarity_threshold,
                quantization_config=qdrant_quantization_config,
                embedding_model=qdrant_semantic_cache_embedding_model,
            )
        elif type == LiteLLMCacheType.LOCAL:
            self.cache = InMemoryCache()
        elif type == LiteLLMCacheType.S3:
            self.cache = S3Cache(
                s3_bucket_name=s3_bucket_name,
                s3_region_name=s3_region_name,
                s3_api_version=s3_api_version,
                s3_use_ssl=s3_use_ssl,
                s3_verify=s3_verify,
                s3_endpoint_url=s3_endpoint_url,
                s3_aws_access_key_id=s3_aws_access_key_id,
                s3_aws_secret_access_key=s3_aws_secret_access_key,
                s3_aws_session_token=s3_aws_session_token,
                s3_config=s3_config,
                s3_path=s3_path,
                **kwargs,
            )
        elif type == LiteLLMCacheType.DISK:
            self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
        if "cache" not in litellm.input_callback:
            litellm.input_callback.append("cache")
        if "cache" not in litellm.success_callback:
            litellm.logging_callback_manager.add_litellm_success_callback("cache")
        if "cache" not in litellm._async_success_callback:
            litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
        self.supported_call_types = supported_call_types  # default to ["completion", "acompletion", "embedding", "aembedding"]
        self.type = type
        self.namespace = namespace
        self.redis_flush_size = redis_flush_size
        self.ttl = ttl
        self.mode: CacheMode = mode or CacheMode.default_on

        if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
            self.ttl = default_in_memory_ttl

        if (
            self.type == LiteLLMCacheType.REDIS
            or self.type == LiteLLMCacheType.REDIS_SEMANTIC
        ) and default_in_redis_ttl is not None:
            self.ttl = default_in_redis_ttl

        if self.namespace is not None and isinstance(self.cache, RedisCache):
            self.cache.namespace = self.namespace

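    # --- Illustrative sketch (not part of the library): wiring up a cache. ---
    # This only shows how the constructor above is typically used; the Redis
    # host/port/password and the ttl values are placeholders, not defaults
    # shipped by litellm.
    #
    #   import litellm
    #   from litellm.caching.caching import Cache, LiteLLMCacheType
    #
    #   # in-process cache, entries expire after 10 minutes
    #   litellm.cache = Cache(type=LiteLLMCacheType.LOCAL, ttl=600)
    #
    #   # shared Redis cache, keys prefixed with "team-a"
    #   litellm.cache = Cache(
    #       type=LiteLLMCacheType.REDIS,
    #       host="localhost",
    #       port="6379",
    #       password="my-password",
    #       namespace="team-a",
    #   )
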
    def get_cache_key(self, **kwargs) -> str:
        """
        Get the cache key for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments, or None if no cache key could be generated.
        """
        cache_key = ""
        # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

        preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
        if preset_cache_key is not None:
            verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
            return preset_cache_key

        combined_kwargs = ModelParamHelper._get_all_llm_api_params()
        litellm_param_kwargs = all_litellm_params
        for param in kwargs:
            if param in combined_kwargs:
                param_value: Optional[str] = self._get_param_value(param, kwargs)
                if param_value is not None:
                    cache_key += f"{str(param)}: {str(param_value)}"
            elif (
                param not in litellm_param_kwargs
            ):  # check if user passed in optional param - e.g. top_k
                if (
                    litellm.enable_caching_on_provider_specific_optional_params is True
                ):  # feature flagged for now
                    if kwargs[param] is None:
                        continue  # ignore None params
                    param_value = kwargs[param]
                    cache_key += f"{str(param)}: {str(param_value)}"

        verbose_logger.debug("\nCreated cache key: %s", cache_key)
        hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
        hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
        self._set_preset_cache_key_in_kwargs(
            preset_cache_key=hashed_cache_key, **kwargs
        )
        return hashed_cache_key

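    # Illustrative sketch of the key construction above (assumes a configured
    # litellm.cache): identical request kwargs hash to the same SHA-256 key,
    # while changing any cached parameter (model, messages, temperature, ...)
    # produces a different key.
    #
    #   kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}
    #   key_a = litellm.cache.get_cache_key(**kwargs)
    #   key_b = litellm.cache.get_cache_key(**kwargs)
    #   assert key_a == key_b  # deterministic
    #   key_c = litellm.cache.get_cache_key(temperature=0.2, **kwargs)
    #   assert key_c != key_a  # temperature is part of the key
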
    def _get_param_value(
        self,
        param: str,
        kwargs: dict,
    ) -> Optional[str]:
        """
        Get the value for the given param from kwargs
        """
        if param == "model":
            return self._get_model_param_value(kwargs)
        elif param == "file":
            return self._get_file_param_value(kwargs)
        return kwargs[param]

    def _get_model_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'model' param from kwargs

        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
        2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
        3. Else use the `model` passed in kwargs
        """
        metadata: Dict = kwargs.get("metadata", {}) or {}
        litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
        metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
        model_group: Optional[str] = metadata.get(
            "model_group"
        ) or metadata_in_litellm_params.get("model_group")
        caching_group = self._get_caching_group(metadata, model_group)
        return caching_group or model_group or kwargs["model"]

    def _get_caching_group(
        self, metadata: dict, model_group: Optional[str]
    ) -> Optional[str]:
        caching_groups: Optional[List] = metadata.get("caching_groups", [])
        if caching_groups:
            for group in caching_groups:
                if model_group in group:
                    return str(group)
        return None

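    # Illustrative sketch of the model-group / caching-group resolution above
    # (the names are placeholders): when metadata declares caching_groups, any
    # deployment inside a group shares one cache entry; otherwise the router's
    # model_group (or the raw `model`) is used.
    #
    #   metadata = {
    #       "model_group": "gpt-4-eu",
    #       "caching_groups": [("gpt-4-eu", "gpt-4-us")],
    #   }
    #   # -> _get_caching_group returns "('gpt-4-eu', 'gpt-4-us')", so both
    #   #    deployments resolve to the same cache key component.
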
    def _get_file_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
        """
        file = kwargs.get("file")
        metadata = kwargs.get("metadata", {})
        litellm_params = kwargs.get("litellm_params", {})
        return (
            metadata.get("file_checksum")
            or getattr(file, "name", None)
            or metadata.get("file_name")
            or litellm_params.get("file_name")
        )

    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
        """
        Get the preset cache key from kwargs["litellm_params"]

        We use _get_preset_cache_keys for two reasons

        1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
        2. avoid doing duplicate / repeated work
        """
        if kwargs:
            if "litellm_params" in kwargs:
                return kwargs["litellm_params"].get("preset_cache_key", None)
        return None

    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
        """
        Set the calculated cache key in kwargs

        This is used to avoid doing duplicate / repeated work

        Placed in kwargs["litellm_params"]
        """
        if kwargs:
            if "litellm_params" in kwargs:
                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key

    @staticmethod
    def _get_hashed_cache_key(cache_key: str) -> str:
        """
        Get the hashed cache key for the given cache key.

        Use hashlib to create a sha256 hash of the cache key

        Args:
            cache_key (str): The cache key to hash.

        Returns:
            str: The hashed cache key.
        """
        hash_object = hashlib.sha256(cache_key.encode())
        # Hexadecimal representation of the hash
        hash_hex = hash_object.hexdigest()
        verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
        return hash_hex

    def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
        """
        If a redis namespace is provided, add it to the cache key

        Args:
            hash_hex (str): The hashed cache key.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The final hashed cache key with the redis namespace.
        """
        dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
        namespace = (
            dynamic_cache_control.get("namespace")
            or kwargs.get("metadata", {}).get("redis_namespace")
            or self.namespace
        )
        if namespace:
            hash_hex = f"{namespace}:{hash_hex}"
        verbose_logger.debug("Final hashed key: %s", hash_hex)
        return hash_hex

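    # Illustrative note: with a namespace configured (on the Cache object,
    # per-request via kwargs["cache"]["namespace"], or in
    # metadata["redis_namespace"]), the hashed key above is simply prefixed,
    # e.g. "team-a:<sha256-hex>", so multiple tenants can share one Redis
    # instance without key collisions.
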
    def generate_streaming_content(self, content):
        chunk_size = 5  # Adjust the chunk size as needed
        for i in range(0, len(content), chunk_size):
            yield {
                "choices": [
                    {
                        "delta": {
                            "role": "assistant",
                            "content": content[i : i + chunk_size],
                        }
                    }
                ]
            }
            time.sleep(CACHED_STREAMING_CHUNK_DELAY)

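    # Illustrative sketch: generate_streaming_content re-chunks a cached string
    # so a cache hit can still be consumed like a stream. For example,
    #
    #   "".join(
    #       chunk["choices"][0]["delta"]["content"]
    #       for chunk in cache.generate_streaming_content("hello world")
    #   )
    #
    # rebuilds the original "hello world", with a small pause
    # (CACHED_STREAMING_CHUNK_DELAY) between chunks.
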
    def _get_cache_logic(
        self,
        cached_result: Optional[Any],
        max_age: Optional[float],
    ):
        """
        Common get cache logic across sync + async implementations
        """
        # Check if a timestamp was stored with the cached response
        if (
            cached_result is not None
            and isinstance(cached_result, dict)
            and "timestamp" in cached_result
        ):
            timestamp = cached_result["timestamp"]
            current_time = time.time()

            # Calculate age of the cached response
            response_age = current_time - timestamp

            # Check if the cached response is older than the max-age
            if max_age is not None and response_age > max_age:
                return None  # Cached response is too old

            # If the response is fresh, or there's no max-age requirement, return the cached response
            # cached_response may be stored as a serialized (byte)string, e.g. b"{...}"; convert it back to a dict
            cached_response = cached_result.get("response")
            try:
                if isinstance(cached_response, dict):
                    pass
                else:
                    cached_response = json.loads(
                        cached_response  # type: ignore
                    )  # Convert string to dictionary
            except Exception:
                cached_response = ast.literal_eval(cached_response)  # type: ignore
            return cached_response
        return cached_result

    def get_cache(self, **kwargs):
        """
        Retrieves the cached result for the given arguments.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            The cached result if it exists, otherwise None.
        """
        try:  # never block execution
            if self.should_use_cache(**kwargs) is not True:
                return
            messages = kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
                max_age = (
                    cache_control_args.get("s-maxage")
                    or cache_control_args.get("s-max-age")
                    or float("inf")
                )
                cached_result = self.cache.get_cache(cache_key, messages=messages)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

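    # Illustrative sketch of the per-request cache controls read above (the
    # values are placeholders): passing cache={"s-maxage": 60} on a completion
    # call makes get_cache treat any entry older than 60 seconds as a miss,
    # even though the entry remains in the backing store.
    #
    #   litellm.completion(
    #       model="gpt-3.5-turbo",
    #       messages=[{"role": "user", "content": "hi"}],
    #       cache={"s-maxage": 60},
    #   )
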
    async def async_get_cache(self, **kwargs):
        """
        Async get cache implementation.

        Used for embedding calls in async wrapper
        """

        try:  # never block execution
            if self.should_use_cache(**kwargs) is not True:
                return

            kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def _add_cache_logic(self, result, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """
        try:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                if isinstance(result, BaseModel):
                    result = result.model_dump_json()

                ## DEFAULT TTL ##
                if self.ttl is not None:
                    kwargs["ttl"] = self.ttl
                ## Get Cache-Controls ##
                _cache_kwargs = kwargs.get("cache", None)
                if isinstance(_cache_kwargs, dict):
                    for k, v in _cache_kwargs.items():
                        if k == "ttl":
                            kwargs["ttl"] = v

                cached_data = {"timestamp": time.time(), "response": result}
                return cache_key, cached_data, kwargs
            else:
                raise Exception("cache key is None")
        except Exception as e:
            raise e

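    # Illustrative sketch: besides the cache-wide default ttl, a per-request
    # ttl can ride along in the `cache` dict and is copied into kwargs["ttl"]
    # above before the entry is written, e.g.
    #
    #   litellm.completion(..., cache={"ttl": 120})  # this response expires after 120s
    #
    # The stored envelope is {"timestamp": <write time>, "response": <result>},
    # which is what _get_cache_logic unwraps on the read path.
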
    def add_cache(self, result, **kwargs):
        """
        Adds a result to the cache.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    async def async_add_cache(self, result, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            if self.type == "redis" and self.redis_flush_size is not None:
                # high traffic - fill in results in memory and then flush
                await self.batch_cache_write(result, **kwargs)
            else:
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=result, **kwargs
                )

                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    async def async_add_cache_pipeline(self, result, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to prevent using too many clients
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return

            # set default ttl if not set
            if self.ttl is not None:
                kwargs["ttl"] = self.ttl

            cache_list = []
            for idx, i in enumerate(kwargs["input"]):
                preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
                kwargs["cache_key"] = preset_cache_key
                embedding_response = result.data[idx]
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=embedding_response,
                    **kwargs,
                )
                cache_list.append((cache_key, cached_data))

            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # if async_set_cache_pipeline:
            #     await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # else:
            #     tasks = []
            #     for val in cache_list:
            #         tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
            #     await asyncio.gather(*tasks)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    def should_use_cache(self, **kwargs):
        """
        Returns true if we should use the cache for LLM API calls

        If cache is default_on then this is True
        If cache is default_off then this is only true when user has opted in to use cache
        """
        if self.mode == CacheMode.default_on:
            return True

        # when mode == default_off -> Cache is opt in only
        _cache = kwargs.get("cache", None)
        verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
        if _cache and isinstance(_cache, dict):
            if _cache.get("use-cache", False) is True:
                return True
        return False

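    # Illustrative sketch of the opt-in path above: with
    # Cache(mode=CacheMode.default_off), nothing is cached unless the caller
    # explicitly opts in on a request, e.g.
    #
    #   litellm.completion(..., cache={"use-cache": True})
    #
    # With the default mode (default_on), every supported call type is cached.
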
    async def batch_cache_write(self, result, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

    async def ping(self):
        cache_ping = getattr(self.cache, "ping")
        if cache_ping:
            return await cache_ping()
        return None

    async def delete_cache_keys(self, keys):
        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
        if cache_delete_cache_keys:
            return await cache_delete_cache_keys(keys)
        return None

    async def disconnect(self):
        if hasattr(self.cache, "disconnect"):
            await self.cache.disconnect()

    def _supports_async(self) -> bool:
        """
        Internal method to check if the cache type supports async get/set operations

        Only S3 Cache Does NOT support async operations
        """
        if self.type and self.type == LiteLLMCacheType.S3:
            return False
        return True


def enable_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
            The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
    """
    print_verbose("LiteLLM: Enabling Cache")
    if "cache" not in litellm.input_callback:
        litellm.input_callback.append("cache")
    if "cache" not in litellm.success_callback:
        litellm.logging_callback_manager.add_litellm_success_callback("cache")
    if "cache" not in litellm._async_success_callback:
        litellm.logging_callback_manager.add_litellm_async_success_callback("cache")

    if litellm.cache is None:
        litellm.cache = Cache(
            type=type,
            host=host,
            port=port,
            password=password,
            supported_call_types=supported_call_types,
            **kwargs,
        )
    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")

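# Illustrative usage sketch for the module-level helpers (values are
# placeholders): enable_cache() wires the "cache" callbacks and creates a
# local litellm.Cache if none exists; disable_cache() removes them again.
#
#   import litellm
#   from litellm.caching.caching import enable_cache, disable_cache
#
#   enable_cache()  # defaults to an in-memory ("local") cache
#   first = litellm.completion(
#       model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}]
#   )
#   second = litellm.completion(
#       model="gpt-3.5-turbo", messages=[{"role": "user", "content": "hi"}]
#   )  # identical request -> served from cache
#   disable_cache()
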
def update_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Update the cache for LiteLLM.

    Args:
        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
            The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None
    """
    print_verbose("LiteLLM: Updating Cache")
    litellm.cache = Cache(
        type=type,
        host=host,
        port=port,
        password=password,
        supported_call_types=supported_call_types,
        **kwargs,
    )
    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")

def disable_cache():
    """
    Disable the cache used by LiteLLM.

    This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.

    Parameters:
    None

    Returns:
    None
    """
    from contextlib import suppress

    print_verbose("LiteLLM: Disabling Cache")
    with suppress(ValueError):
        litellm.input_callback.remove("cache")
        litellm.success_callback.remove("cache")
        litellm._async_success_callback.remove("cache")

    litellm.cache = None
    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")