structure saas with tools
@@ -0,0 +1,906 @@
"""
This module contains LLMCachingHandler.

It exposes two methods:
    - async_get_cache
    - async_set_cache

This file is a wrapper around caching.py.

The LLMCachingHandler class handles caching logic specific to LLM API requests (completion / embedding / text_completion / transcription, etc.).

It uses the cache the user has set up (RedisCache, S3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache).

Each method calls the appropriate method from caching.py.
"""

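# Illustrative usage sketch (not part of this module; everything outside the handler's
# own methods is an assumption): the handler is intended to wrap an LLM call roughly
# like this --
#
#   handler = LLMCachingHandler(
#       original_function=litellm.acompletion,
#       request_kwargs=kwargs,
#       start_time=datetime.datetime.now(),
#   )
#   caching_response = await handler._async_get_cache(
#       model=model,
#       original_function=litellm.acompletion,
#       logging_obj=logging_obj,
#       start_time=handler.start_time,
#       call_type="acompletion",
#       kwargs=kwargs,
#   )
#   if caching_response.cached_result is not None:
#       return caching_response.cached_result         # cache hit
#   result = await litellm.acompletion(**kwargs)       # cache miss -> call the API
#   await handler.async_set_cache(
#       result=result, original_function=litellm.acompletion, kwargs=kwargs
#   )
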
import asyncio
import datetime
import inspect
import threading
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Union,
)

from pydantic import BaseModel

import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import S3Cache
from litellm.litellm_core_utils.logging_utils import (
    _assemble_complete_response_from_streaming_chunks,
)
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    CallTypes,
    Embedding,
    EmbeddingResponse,
    ModelResponse,
    TextCompletionResponse,
    TranscriptionResponse,
)

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.utils import CustomStreamWrapper
else:
    LiteLLMLoggingObj = Any
    CustomStreamWrapper = Any


class CachingHandlerResponse(BaseModel):
    """
    This is the response object for the caching handler. We need to separate embedding cached responses from (completion / text_completion / transcription) cached responses.

    For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others.
    """

    cached_result: Optional[Any] = None
    final_embedding_cached_response: Optional[EmbeddingResponse] = None
    # Set to True when every element in the embedding input list is a cache hit.
    # If True, return final_embedding_cached_response; no API call is needed.
    embedding_all_elements_cache_hit: bool = False


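# Illustrative example (assumed values, not executed): for an embedding request with
# input=["a", "b", "c"] where only "b" was previously cached, _async_get_cache returns
# roughly:
#
#   CachingHandlerResponse(
#       cached_result=None,
#       final_embedding_cached_response=EmbeddingResponse(
#           data=[None, Embedding(embedding=[...], index=1, object="embedding"), None]
#       ),
#       embedding_all_elements_cache_hit=False,
#   )
#
# The None slots are filled in later by _combine_cached_embedding_response_with_api_result.
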
class LLMCachingHandler:
    def __init__(
        self,
        original_function: Callable,
        request_kwargs: Dict[str, Any],
        start_time: datetime.datetime,
    ):
        self.async_streaming_chunks: List[ModelResponse] = []
        self.sync_streaming_chunks: List[ModelResponse] = []
        self.request_kwargs = request_kwargs
        self.original_function = original_function
        self.start_time = start_time

    async def _async_get_cache(
        self,
        model: str,
        original_function: Callable,
        logging_obj: LiteLLMLoggingObj,
        start_time: datetime.datetime,
        call_type: str,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ) -> CachingHandlerResponse:
        """
        Internal method to get from the cache.
        Handles different call types (embeddings, chat/completions, text_completion, transcription)
        and accordingly returns the cached response

        Args:
            model: str:
            original_function: Callable:
            logging_obj: LiteLLMLoggingObj:
            start_time: datetime.datetime:
            call_type: str:
            kwargs: Dict[str, Any]:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            CachingHandlerResponse:
        Raises:
            None
        """
        from litellm.utils import CustomStreamWrapper

        args = args or ()

        final_embedding_cached_response: Optional[EmbeddingResponse] = None
        embedding_all_elements_cache_hit: bool = False
        cached_result: Optional[Any] = None
        if (
            (kwargs.get("caching", None) is None and litellm.cache is not None)
            or kwargs.get("caching", False) is True
        ) and (
            kwargs.get("cache", {}).get("no-cache", False) is not True
        ):  # allow users to control returning cached responses from the completion function
            if litellm.cache is not None and self._is_call_type_supported_by_cache(
                original_function=original_function
            ):
                verbose_logger.debug("Checking Cache")
                cached_result = await self._retrieve_from_cache(
                    call_type=call_type,
                    kwargs=kwargs,
                    args=args,
                )

                if cached_result is not None and not isinstance(cached_result, list):
                    verbose_logger.debug("Cache Hit!")
                    cache_hit = True
                    end_time = datetime.datetime.now()
                    model, _, _, _ = litellm.get_llm_provider(
                        model=model,
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        api_base=kwargs.get("api_base", None),
                        api_key=kwargs.get("api_key", None),
                    )
                    self._update_litellm_logging_obj_environment(
                        logging_obj=logging_obj,
                        model=model,
                        kwargs=kwargs,
                        cached_result=cached_result,
                        is_async=True,
                    )

                    call_type = original_function.__name__

                    cached_result = self._convert_cached_result_to_model_response(
                        cached_result=cached_result,
                        call_type=call_type,
                        kwargs=kwargs,
                        logging_obj=logging_obj,
                        model=model,
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        args=args,
                    )
                    if kwargs.get("stream", False) is False:
                        # LOG SUCCESS
                        self._async_log_cache_hit_on_callbacks(
                            logging_obj=logging_obj,
                            cached_result=cached_result,
                            start_time=start_time,
                            end_time=end_time,
                            cache_hit=cache_hit,
                        )
                    cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
                        **kwargs
                    )
                    if (
                        isinstance(cached_result, BaseModel)
                        or isinstance(cached_result, CustomStreamWrapper)
                    ) and hasattr(cached_result, "_hidden_params"):
                        cached_result._hidden_params["cache_key"] = cache_key  # type: ignore
                    return CachingHandlerResponse(cached_result=cached_result)
                elif (
                    call_type == CallTypes.aembedding.value
                    and cached_result is not None
                    and isinstance(cached_result, list)
                    and litellm.cache is not None
                    and not isinstance(
                        litellm.cache.cache, S3Cache
                    )  # s3 doesn't support bulk writing. Exclude.
                ):
                    (
                        final_embedding_cached_response,
                        embedding_all_elements_cache_hit,
                    ) = self._process_async_embedding_cached_response(
                        final_embedding_cached_response=final_embedding_cached_response,
                        cached_result=cached_result,
                        kwargs=kwargs,
                        logging_obj=logging_obj,
                        start_time=start_time,
                        model=model,
                    )
                    return CachingHandlerResponse(
                        final_embedding_cached_response=final_embedding_cached_response,
                        embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
                    )
        verbose_logger.debug(f"CACHE RESULT: {cached_result}")
        return CachingHandlerResponse(
            cached_result=cached_result,
            final_embedding_cached_response=final_embedding_cached_response,
        )

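    # Illustrative caller-side sketch (assumed request kwargs, not executed): the
    # per-request cache controls read at the top of _async_get_cache --
    #
    #   await litellm.acompletion(model="gpt-3.5-turbo", messages=msgs, caching=True)
    #   await litellm.acompletion(model="gpt-3.5-turbo", messages=msgs, cache={"no-cache": True})
    #
    # The first explicitly opts the request into a cache check; the second bypasses
    # cache reads for that one request.
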
    def _sync_get_cache(
        self,
        model: str,
        original_function: Callable,
        logging_obj: LiteLLMLoggingObj,
        start_time: datetime.datetime,
        call_type: str,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ) -> CachingHandlerResponse:
        from litellm.utils import CustomStreamWrapper

        args = args or ()
        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        cached_result: Optional[Any] = None
        if litellm.cache is not None and self._is_call_type_supported_by_cache(
            original_function=original_function
        ):
            print_verbose("Checking Cache")
            cached_result = litellm.cache.get_cache(**new_kwargs)
            if cached_result is not None:
                if "detail" in cached_result:
                    # implies an error occurred
                    pass
                else:
                    call_type = original_function.__name__
                    cached_result = self._convert_cached_result_to_model_response(
                        cached_result=cached_result,
                        call_type=call_type,
                        kwargs=kwargs,
                        logging_obj=logging_obj,
                        model=model,
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        args=args,
                    )

                    # LOG SUCCESS
                    cache_hit = True
                    end_time = datetime.datetime.now()
                    (
                        model,
                        custom_llm_provider,
                        dynamic_api_key,
                        api_base,
                    ) = litellm.get_llm_provider(
                        model=model or "",
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        api_base=kwargs.get("api_base", None),
                        api_key=kwargs.get("api_key", None),
                    )
                    self._update_litellm_logging_obj_environment(
                        logging_obj=logging_obj,
                        model=model,
                        kwargs=kwargs,
                        cached_result=cached_result,
                        is_async=False,
                    )

                    threading.Thread(
                        target=logging_obj.success_handler,
                        args=(cached_result, start_time, end_time, cache_hit),
                    ).start()
                    cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
                        **kwargs
                    )
                    if (
                        isinstance(cached_result, BaseModel)
                        or isinstance(cached_result, CustomStreamWrapper)
                    ) and hasattr(cached_result, "_hidden_params"):
                        cached_result._hidden_params["cache_key"] = cache_key  # type: ignore
                    return CachingHandlerResponse(cached_result=cached_result)
        return CachingHandlerResponse(cached_result=cached_result)

    def _process_async_embedding_cached_response(
        self,
        final_embedding_cached_response: Optional[EmbeddingResponse],
        cached_result: List[Optional[Dict[str, Any]]],
        kwargs: Dict[str, Any],
        logging_obj: LiteLLMLoggingObj,
        start_time: datetime.datetime,
        model: str,
    ) -> Tuple[Optional[EmbeddingResponse], bool]:
        """
        Returns the final embedding cached response and a boolean indicating whether all elements in the list were cache hits.

        For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others.
        This function processes the cached embedding responses and builds the final embedding cached response.

        Args:
            final_embedding_cached_response: Optional[EmbeddingResponse]:
            cached_result: List[Optional[Dict[str, Any]]]:
            kwargs: Dict[str, Any]:
            logging_obj: LiteLLMLoggingObj:
            start_time: datetime.datetime:
            model: str:

        Returns:
            Tuple[Optional[EmbeddingResponse], bool]:
                The final embedding cached response and a boolean indicating whether all elements in the list were cache hits.
        """
        embedding_all_elements_cache_hit: bool = False
        remaining_list = []
        non_null_list = []
        for idx, cr in enumerate(cached_result):
            if cr is None:
                remaining_list.append(kwargs["input"][idx])
            else:
                non_null_list.append((idx, cr))
        original_kwargs_input = kwargs["input"]
        kwargs["input"] = remaining_list
        if len(non_null_list) > 0:
            print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
            final_embedding_cached_response = EmbeddingResponse(
                model=kwargs.get("model"),
                data=[None] * len(original_kwargs_input),
            )
            final_embedding_cached_response._hidden_params["cache_hit"] = True

            for val in non_null_list:
                idx, cr = val  # (idx, cr) tuple
                if cr is not None:
                    final_embedding_cached_response.data[idx] = Embedding(
                        embedding=cr["embedding"],
                        index=idx,
                        object="embedding",
                    )
        if len(remaining_list) == 0:
            # LOG SUCCESS
            cache_hit = True
            embedding_all_elements_cache_hit = True
            end_time = datetime.datetime.now()
            (
                model,
                custom_llm_provider,
                dynamic_api_key,
                api_base,
            ) = litellm.get_llm_provider(
                model=model,
                custom_llm_provider=kwargs.get("custom_llm_provider", None),
                api_base=kwargs.get("api_base", None),
                api_key=kwargs.get("api_key", None),
            )

            self._update_litellm_logging_obj_environment(
                logging_obj=logging_obj,
                model=model,
                kwargs=kwargs,
                cached_result=final_embedding_cached_response,
                is_async=True,
                is_embedding=True,
            )
            self._async_log_cache_hit_on_callbacks(
                logging_obj=logging_obj,
                cached_result=final_embedding_cached_response,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
            )
            return final_embedding_cached_response, embedding_all_elements_cache_hit
        return final_embedding_cached_response, embedding_all_elements_cache_hit

    def _combine_cached_embedding_response_with_api_result(
        self,
        _caching_handler_response: CachingHandlerResponse,
        embedding_response: EmbeddingResponse,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
    ) -> EmbeddingResponse:
        """
        Combines the cached embedding response with the API EmbeddingResponse

        For caching there can be a cache hit for some of the inputs in the list and a cache miss for others.
        This function combines the cached embedding response with the API EmbeddingResponse.

        Args:
            _caching_handler_response: CachingHandlerResponse:
            embedding_response: EmbeddingResponse:

        Returns:
            EmbeddingResponse:
        """
        if _caching_handler_response.final_embedding_cached_response is None:
            return embedding_response

        idx = 0
        final_data_list = []
        for item in _caching_handler_response.final_embedding_cached_response.data:
            if item is None and embedding_response.data is not None:
                final_data_list.append(embedding_response.data[idx])
                idx += 1
            else:
                final_data_list.append(item)

        _caching_handler_response.final_embedding_cached_response.data = final_data_list
        _caching_handler_response.final_embedding_cached_response._hidden_params[
            "cache_hit"
        ] = True
        _caching_handler_response.final_embedding_cached_response._response_ms = (
            end_time - start_time
        ).total_seconds() * 1000
        return _caching_handler_response.final_embedding_cached_response

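    # Illustrative example (assumed values): if the cached response data is
    # [None, E1, None] (only the second input was a cache hit) and the API returned
    # embeddings [A0, A2] for the remaining inputs, the combined data becomes
    # [A0, E1, A2] -- API results are consumed in order and slotted into the None gaps.
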
    def _async_log_cache_hit_on_callbacks(
        self,
        logging_obj: LiteLLMLoggingObj,
        cached_result: Any,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
        cache_hit: bool,
    ):
        """
        Helper function to log the success of a cached result on callbacks

        Args:
            logging_obj (LiteLLMLoggingObj): The logging object.
            cached_result: The cached result.
            start_time (datetime): The start time of the operation.
            end_time (datetime): The end time of the operation.
            cache_hit (bool): Whether it was a cache hit.
        """
        asyncio.create_task(
            logging_obj.async_success_handler(
                cached_result, start_time, end_time, cache_hit
            )
        )
        threading.Thread(
            target=logging_obj.success_handler,
            args=(cached_result, start_time, end_time, cache_hit),
        ).start()

    async def _retrieve_from_cache(
        self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
    ) -> Optional[Any]:
        """
        Internal method to
        - get cache key
        - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
        - async get cache value
        - return the cached value

        Args:
            call_type: str:
            kwargs: Dict[str, Any]:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            Optional[Any]:
        Raises:
            None
        """
        if litellm.cache is None:
            return None

        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        cached_result: Optional[Any] = None
        if call_type == CallTypes.aembedding.value and isinstance(
            new_kwargs["input"], list
        ):
            tasks = []
            for idx, i in enumerate(new_kwargs["input"]):
                preset_cache_key = litellm.cache.get_cache_key(
                    **{**new_kwargs, "input": i}
                )
                tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
            cached_result = await asyncio.gather(*tasks)
            ## check if cached result is None ##
            if cached_result is not None and isinstance(cached_result, list):
                # set cached_result to None if all elements are None
                if all(result is None for result in cached_result):
                    cached_result = None
        else:
            if litellm.cache._supports_async() is True:
                cached_result = await litellm.cache.async_get_cache(**new_kwargs)
            else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                cached_result = litellm.cache.get_cache(**new_kwargs)
        return cached_result

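    # Illustrative sketch (assumed values): for an aembedding call with
    # input=["hello", "world"], the loop above builds one cache key per element,
    # roughly get_cache_key(**{**new_kwargs, "input": "hello"}) and
    # get_cache_key(**{**new_kwargs, "input": "world"}), then awaits the per-key
    # lookups concurrently via asyncio.gather.
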
    def _convert_cached_result_to_model_response(
        self,
        cached_result: Any,
        call_type: str,
        kwargs: Dict[str, Any],
        logging_obj: LiteLLMLoggingObj,
        model: str,
        args: Tuple[Any, ...],
        custom_llm_provider: Optional[str] = None,
    ) -> Optional[
        Union[
            ModelResponse,
            TextCompletionResponse,
            EmbeddingResponse,
            RerankResponse,
            TranscriptionResponse,
            CustomStreamWrapper,
        ]
    ]:
        """
        Internal method to process the cached result

        Checks the call type and converts the cached result to the appropriate model response object,
        e.g. if call_type is text_completion -> returns a TextCompletionResponse object

        Args:
            cached_result: Any:
            call_type: str:
            kwargs: Dict[str, Any]:
            logging_obj: LiteLLMLoggingObj:
            model: str:
            custom_llm_provider: Optional[str] = None:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            Optional[Any]:
        """
        from litellm.utils import convert_to_model_response_object

        if (
            call_type == CallTypes.acompletion.value
            or call_type == CallTypes.completion.value
        ) and isinstance(cached_result, dict):
            if kwargs.get("stream", False) is True:
                cached_result = self._convert_cached_stream_response(
                    cached_result=cached_result,
                    call_type=call_type,
                    logging_obj=logging_obj,
                    model=model,
                )
            else:
                cached_result = convert_to_model_response_object(
                    response_object=cached_result,
                    model_response_object=ModelResponse(),
                )
        if (
            call_type == CallTypes.atext_completion.value
            or call_type == CallTypes.text_completion.value
        ) and isinstance(cached_result, dict):
            if kwargs.get("stream", False) is True:
                cached_result = self._convert_cached_stream_response(
                    cached_result=cached_result,
                    call_type=call_type,
                    logging_obj=logging_obj,
                    model=model,
                )
            else:
                cached_result = TextCompletionResponse(**cached_result)
        elif (
            call_type == CallTypes.aembedding.value
            or call_type == CallTypes.embedding.value
        ) and isinstance(cached_result, dict):
            cached_result = convert_to_model_response_object(
                response_object=cached_result,
                model_response_object=EmbeddingResponse(),
                response_type="embedding",
            )

        elif (
            call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
        ) and isinstance(cached_result, dict):
            cached_result = convert_to_model_response_object(
                response_object=cached_result,
                model_response_object=None,
                response_type="rerank",
            )
        elif (
            call_type == CallTypes.atranscription.value
            or call_type == CallTypes.transcription.value
        ) and isinstance(cached_result, dict):
            hidden_params = {
                "model": "whisper-1",
                "custom_llm_provider": custom_llm_provider,
                "cache_hit": True,
            }
            cached_result = convert_to_model_response_object(
                response_object=cached_result,
                model_response_object=TranscriptionResponse(),
                response_type="audio_transcription",
                hidden_params=hidden_params,
            )

        if (
            hasattr(cached_result, "_hidden_params")
            and cached_result._hidden_params is not None
            and isinstance(cached_result._hidden_params, dict)
        ):
            cached_result._hidden_params["cache_hit"] = True
        return cached_result

    def _convert_cached_stream_response(
        self,
        cached_result: Any,
        call_type: str,
        logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> CustomStreamWrapper:
        from litellm.utils import (
            CustomStreamWrapper,
            convert_to_streaming_response,
            convert_to_streaming_response_async,
        )

        _stream_cached_result: Union[AsyncGenerator, Generator]
        if (
            call_type == CallTypes.acompletion.value
            or call_type == CallTypes.atext_completion.value
        ):
            _stream_cached_result = convert_to_streaming_response_async(
                response_object=cached_result,
            )
        else:
            _stream_cached_result = convert_to_streaming_response(
                response_object=cached_result,
            )
        return CustomStreamWrapper(
            completion_stream=_stream_cached_result,
            model=model,
            custom_llm_provider="cached_response",
            logging_obj=logging_obj,
        )

    async def async_set_cache(
        self,
        result: Any,
        original_function: Callable,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ):
        """
        Internal method to check the type of the result & the cache used, and add the result to the cache accordingly

        Args:
            result: Any:
            original_function: Callable:
            kwargs: Dict[str, Any]:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            None
        Raises:
            None
        """
        if litellm.cache is None:
            return

        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                original_function,
                args,
            )
        )
        # [OPTIONAL] ADD TO CACHE
        if self._should_store_result_in_cache(
            original_function=original_function, kwargs=new_kwargs
        ):
            if (
                isinstance(result, litellm.ModelResponse)
                or isinstance(result, litellm.EmbeddingResponse)
                or isinstance(result, TranscriptionResponse)
                or isinstance(result, RerankResponse)
            ):
                if (
                    isinstance(result, EmbeddingResponse)
                    and isinstance(new_kwargs["input"], list)
                    and litellm.cache is not None
                    and not isinstance(
                        litellm.cache.cache, S3Cache
                    )  # s3 doesn't support bulk writing. Exclude.
                ):
                    asyncio.create_task(
                        litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
                    )
                elif isinstance(litellm.cache.cache, S3Cache):
                    threading.Thread(
                        target=litellm.cache.add_cache,
                        args=(result,),
                        kwargs=new_kwargs,
                    ).start()
                else:
                    asyncio.create_task(
                        litellm.cache.async_add_cache(
                            result.model_dump_json(), **new_kwargs
                        )
                    )
            else:
                asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))

    def sync_set_cache(
        self,
        result: Any,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ):
        """
        Sync internal method to add the result to the cache
        """

        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        if litellm.cache is None:
            return

        if self._should_store_result_in_cache(
            original_function=self.original_function, kwargs=new_kwargs
        ):
            litellm.cache.add_cache(result, **new_kwargs)

    def _should_store_result_in_cache(
        self, original_function: Callable, kwargs: Dict[str, Any]
    ) -> bool:
        """
        Helper function to determine if the result should be stored in the cache.

        Returns:
            bool: True if the result should be stored in the cache, False otherwise.
        """
        return (
            (litellm.cache is not None)
            and litellm.cache.supported_call_types is not None
            and (str(original_function.__name__) in litellm.cache.supported_call_types)
            and (kwargs.get("cache", {}).get("no-store", False) is not True)
        )

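    # Illustrative caller-side sketch (assumed request kwargs, not executed): the
    # "no-store" flag checked above lets a request opt out of cache writes while
    # still allowing cache reads --
    #
    #   await litellm.acompletion(model="gpt-3.5-turbo", messages=msgs, cache={"no-store": True})
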
    def _is_call_type_supported_by_cache(
        self,
        original_function: Callable,
    ) -> bool:
        """
        Helper function to determine if the call type is supported by the cache.

        Call types are acompletion, aembedding, atext_completion, atranscription, arerank.

        Defined on `litellm.types.utils.CallTypes`

        Returns:
            bool: True if the call type is supported by the cache, False otherwise.
        """
        if (
            litellm.cache is not None
            and litellm.cache.supported_call_types is not None
            and str(original_function.__name__) in litellm.cache.supported_call_types
        ):
            return True
        return False

    async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
        """
        Internal method to add the streaming response to the cache

        - If 'processed_chunk' has a 'finish_reason', assemble a litellm.ModelResponse object
        - Else append the chunk to self.async_streaming_chunks
        """

        complete_streaming_response: Optional[
            Union[ModelResponse, TextCompletionResponse]
        ] = _assemble_complete_response_from_streaming_chunks(
            result=processed_chunk,
            start_time=self.start_time,
            end_time=datetime.datetime.now(),
            request_kwargs=self.request_kwargs,
            streaming_chunks=self.async_streaming_chunks,
            is_async=True,
        )
        # if a complete_streaming_response is assembled, add it to the cache
        if complete_streaming_response is not None:
            await self.async_set_cache(
                result=complete_streaming_response,
                original_function=self.original_function,
                kwargs=self.request_kwargs,
            )

    def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
        """
        Sync internal method to add the streaming response to the cache
        """
        complete_streaming_response: Optional[
            Union[ModelResponse, TextCompletionResponse]
        ] = _assemble_complete_response_from_streaming_chunks(
            result=processed_chunk,
            start_time=self.start_time,
            end_time=datetime.datetime.now(),
            request_kwargs=self.request_kwargs,
            streaming_chunks=self.sync_streaming_chunks,
            is_async=False,
        )

        # if a complete_streaming_response is assembled, add it to the cache
        if complete_streaming_response is not None:
            self.sync_set_cache(
                result=complete_streaming_response,
                kwargs=self.request_kwargs,
            )

    def _update_litellm_logging_obj_environment(
        self,
        logging_obj: LiteLLMLoggingObj,
        model: str,
        kwargs: Dict[str, Any],
        cached_result: Any,
        is_async: bool,
        is_embedding: bool = False,
    ):
        """
        Helper function to update the LiteLLMLoggingObj environment variables.

        Args:
            logging_obj (LiteLLMLoggingObj): The logging object to update.
            model (str): The model being used.
            kwargs (Dict[str, Any]): The keyword arguments from the original function call.
            cached_result (Any): The cached result to log.
            is_async (bool): Whether the call is asynchronous or not.
            is_embedding (bool): Whether the call is for embeddings or not.

        Returns:
            None
        """
        litellm_params = {
            "logger_fn": kwargs.get("logger_fn", None),
            "acompletion": is_async,
            "api_base": kwargs.get("api_base", ""),
            "metadata": kwargs.get("metadata", {}),
            "model_info": kwargs.get("model_info", {}),
            "proxy_server_request": kwargs.get("proxy_server_request", None),
            "stream_response": kwargs.get("stream_response", {}),
        }

        if litellm.cache is not None:
            litellm_params[
                "preset_cache_key"
            ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
        else:
            litellm_params["preset_cache_key"] = None

        logging_obj.update_environment_variables(
            model=model,
            user=kwargs.get("user", None),
            optional_params={},
            litellm_params=litellm_params,
            input=(
                kwargs.get("messages", "")
                if not is_embedding
                else kwargs.get("input", "")
            ),
            api_key=kwargs.get("api_key", None),
            original_response=str(cached_result),
            additional_args=None,
            stream=kwargs.get("stream", False),
        )


def convert_args_to_kwargs(
    original_function: Callable,
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Get the signature of the original function
    signature = inspect.signature(original_function)

    # Get parameter names in the order they appear in the original function
    param_names = list(signature.parameters.keys())

    # Create a mapping of positional arguments to parameter names
    args_to_kwargs = {}
    if args:
        for index, arg in enumerate(args):
            if index < len(param_names):
                param_name = param_names[index]
                args_to_kwargs[param_name] = arg

    return args_to_kwargs
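

# Illustrative example of convert_args_to_kwargs (hypothetical function, not executed):
#
#   def fake_completion(model, messages, stream=False):
#       ...
#
#   convert_args_to_kwargs(fake_completion, ("gpt-4o", [{"role": "user", "content": "hi"}]))
#   # -> {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}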