new mcp servers format
Binary files changed (not shown).
@@ -89,6 +89,16 @@ class LiteLLMCompletionTransformationHandler:
        responses_api_request: ResponsesAPIOptionalRequestParams,
        **kwargs,
    ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:

        previous_response_id: Optional[str] = responses_api_request.get(
            "previous_response_id"
        )
        if previous_response_id:
            litellm_completion_request = await LiteLLMCompletionResponsesConfig.async_responses_api_session_handler(
                previous_response_id=previous_response_id,
                litellm_completion_request=litellm_completion_request,
            )

        litellm_completion_response: Union[
            ModelResponse, litellm.CustomStreamWrapper
        ] = await litellm.acompletion(
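End to end, this hook means that a Responses API request carrying `previous_response_id` has its stored session replayed into the Chat Completion request before `litellm.acompletion` runs. A minimal multi-turn sketch, assuming the Responses API helpers are exported as `litellm.aresponses` and that provider credentials are configured; the model name and prompts are illustrative:

    # Minimal multi-turn sketch; assumes litellm.aresponses is available and
    # provider credentials (e.g. OPENAI_API_KEY) are set in the environment.
    import asyncio

    import litellm


    async def main():
        # First turn (model and prompt are illustrative).
        first = await litellm.aresponses(
            model="openai/gpt-4o-mini",
            input="What is the capital of France?",
        )

        # Second turn: passing previous_response_id makes the handler above
        # replay the stored session into the Chat Completion request.
        second = await litellm.aresponses(
            model="openai/gpt-4o-mini",
            input="And what about Italy?",
            previous_response_id=first.id,
        )
        print(second.id)


    asyncio.run(main())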
@@ -1,59 +0,0 @@
"""
Responses API has previous_response_id, which is the id of the previous response.

LiteLLM needs to maintain a cache of the previous response input, output, previous_response_id, and model.

This class handles that cache.
"""

from typing import List, Optional, Tuple, Union

from typing_extensions import TypedDict

from litellm.caching import InMemoryCache
from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse

RESPONSES_API_PREVIOUS_RESPONSES_CACHE = InMemoryCache()
MAX_PREV_SESSION_INPUTS = 50


class ResponsesAPISessionElement(TypedDict, total=False):
    input: Union[str, ResponseInputParam]
    output: ResponsesAPIResponse
    response_id: str
    previous_response_id: Optional[str]


class SessionHandler:

    def add_completed_response_to_cache(
        self, response_id: str, session_element: ResponsesAPISessionElement
    ):
        RESPONSES_API_PREVIOUS_RESPONSES_CACHE.set_cache(
            key=response_id, value=session_element
        )

    def get_chain_of_previous_input_output_pairs(
        self, previous_response_id: str
    ) -> List[Tuple[ResponseInputParam, ResponsesAPIResponse]]:
        response_api_inputs: List[Tuple[ResponseInputParam, ResponsesAPIResponse]] = []
        current_previous_response_id = previous_response_id

        count_session_elements = 0
        while current_previous_response_id:
            if count_session_elements > MAX_PREV_SESSION_INPUTS:
                break
            session_element = RESPONSES_API_PREVIOUS_RESPONSES_CACHE.get_cache(
                key=current_previous_response_id
            )
            if session_element:
                response_api_inputs.append(
                    (session_element.get("input"), session_element.get("output"))
                )
                current_previous_response_id = session_element.get(
                    "previous_response_id"
                )
            else:
                break
            count_session_elements += 1
        return response_api_inputs
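This hunk deletes the old in-memory session cache. For context, a minimal sketch of how the removed handler chained turns: each completed response was cached under its response_id, and a later request walked backwards from previous_response_id to rebuild the (input, output) history. The ids are placeholders, the import path is the pre-commit module being deleted here, and a real element would also carry the ResponsesAPIResponse under "output":

    # Sketch of the removed in-memory flow (pre-commit module path).
    from litellm.responses.litellm_completion_transformation.session_handler import (
        ResponsesAPISessionElement,
        SessionHandler,
    )

    handler = SessionHandler()

    # First turn: cache the completed response under its response_id.
    handler.add_completed_response_to_cache(
        response_id="resp_1",  # placeholder id
        session_element=ResponsesAPISessionElement(
            input="What is the capital of France?",
            response_id="resp_1",
            previous_response_id=None,
        ),
    )

    # Later turn: walk the chain backwards from previous_response_id to
    # rebuild the (input, output) history.
    pairs = handler.get_chain_of_previous_input_output_pairs(
        previous_response_id="resp_1"
    )
    print(pairs)  # [("What is the capital of France?", None)] in this sketch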
@@ -5,13 +5,21 @@ Handles transforming from Responses API -> LiteLLM completion (Chat Completion
from typing import Any, Dict, List, Optional, Union

from openai.types.responses.tool_param import FunctionToolParam
from typing_extensions import TypedDict

HAS_ENTERPRISE_DIRECTORY = False
try:
    from enterprise.enterprise_hooks.session_handler import (
        _ENTERPRISE_ResponsesSessionHandler,
    )

    HAS_ENTERPRISE_DIRECTORY = True
except ImportError:
    _ENTERPRISE_ResponsesSessionHandler = None  # type: ignore
    HAS_ENTERPRISE_DIRECTORY = False

from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.responses.litellm_completion_transformation.session_handler import (
    ResponsesAPISessionElement,
    SessionHandler,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionResponseMessage,
@@ -48,7 +56,21 @@ from litellm.types.utils import (

########### Initialize Classes used for Responses API ###########
TOOL_CALLS_CACHE = InMemoryCache()
RESPONSES_API_SESSION_HANDLER = SessionHandler()


class ChatCompletionSession(TypedDict, total=False):
    messages: List[
        Union[
            AllMessageValues,
            GenericChatCompletionMessage,
            ChatCompletionMessageToolCall,
            ChatCompletionResponseMessage,
            Message,
        ]
    ]
    litellm_session_id: Optional[str]


########### End of Initialize Classes used for Responses API ###########

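The new ChatCompletionSession is just a TypedDict bundling the replayed chat history with the id of the stored session. A small illustrative construction, assuming the class is importable from the transformation module shown above; message content and session id are placeholders:

    # Illustrative only; import path and values are assumptions.
    from litellm.responses.litellm_completion_transformation.transformation import (
        ChatCompletionSession,
    )

    session = ChatCompletionSession(
        messages=[
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ],
        litellm_session_id="session-123",  # hypothetical id
    )
    print(session["litellm_session_id"])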
@@ -90,7 +112,6 @@ class LiteLLMCompletionResponsesConfig:
            "messages": LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                input=input,
                responses_api_request=responses_api_request,
                previous_response_id=responses_api_request.get("previous_response_id"),
            ),
            "model": model,
            "tool_choice": responses_api_request.get("tool_choice"),
@@ -131,14 +152,14 @@ class LiteLLMCompletionResponsesConfig:
    @staticmethod
    def transform_responses_api_input_to_messages(
        input: Union[str, ResponseInputParam],
        responses_api_request: ResponsesAPIOptionalRequestParams,
        previous_response_id: Optional[str] = None,
        responses_api_request: Union[ResponsesAPIOptionalRequestParams, dict],
    ) -> List[
        Union[
            AllMessageValues,
            GenericChatCompletionMessage,
            ChatCompletionMessageToolCall,
            ChatCompletionResponseMessage,
            Message,
        ]
    ]:
        """
@@ -150,6 +171,7 @@ class LiteLLMCompletionResponsesConfig:
                GenericChatCompletionMessage,
                ChatCompletionMessageToolCall,
                ChatCompletionResponseMessage,
                Message,
            ]
        ] = []
        if responses_api_request.get("instructions"):
@@ -159,24 +181,6 @@ class LiteLLMCompletionResponsesConfig:
                )
            )

        if previous_response_id:
            previous_response_pairs = (
                RESPONSES_API_SESSION_HANDLER.get_chain_of_previous_input_output_pairs(
                    previous_response_id=previous_response_id
                )
            )
            if previous_response_pairs:
                for previous_response_pair in previous_response_pairs:
                    chat_completion_input_messages = LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
                        input=previous_response_pair[0],
                    )
                    chat_completion_output_messages = LiteLLMCompletionResponsesConfig._transform_responses_api_outputs_to_chat_completion_messages(
                        responses_api_output=previous_response_pair[1],
                    )

                    messages.extend(chat_completion_input_messages)
                    messages.extend(chat_completion_output_messages)

        messages.extend(
            LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
                input=input,
@@ -185,6 +189,33 @@ class LiteLLMCompletionResponsesConfig:

        return messages

    @staticmethod
    async def async_responses_api_session_handler(
        previous_response_id: str,
        litellm_completion_request: dict,
    ) -> dict:
        """
        Async hook to get the chain of previous input and output pairs and return a list of Chat Completion messages
        """
        if (
            HAS_ENTERPRISE_DIRECTORY is True
            and _ENTERPRISE_ResponsesSessionHandler is not None
        ):
            chat_completion_session = ChatCompletionSession(
                messages=[], litellm_session_id=None
            )
            if previous_response_id:
                chat_completion_session = await _ENTERPRISE_ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
                    previous_response_id=previous_response_id
                )
            _messages = litellm_completion_request.get("messages") or []
            session_messages = chat_completion_session.get("messages") or []
            litellm_completion_request["messages"] = session_messages + _messages
            litellm_completion_request["litellm_trace_id"] = (
                chat_completion_session.get("litellm_session_id")
            )
        return litellm_completion_request

    @staticmethod
    def _transform_response_input_param_to_chat_completion_message(
        input: Union[str, ResponseInputParam],
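The merge performed by `async_responses_api_session_handler` is deliberately simple: whatever history the enterprise session handler returns is prepended to the messages already on the request, and the request's `litellm_trace_id` is pointed at the stored session. A minimal sketch of that merge with placeholder data and plain dicts, without the enterprise handler:

    # Illustrative merge mirroring async_responses_api_session_handler above;
    # the real session history comes from the enterprise session handler.
    litellm_completion_request = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "And what about Italy?"}],
    }
    chat_completion_session = {
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ],
        "litellm_session_id": "session-123",  # hypothetical id
    }

    session_messages = chat_completion_session.get("messages") or []
    current_messages = litellm_completion_request.get("messages") or []
    litellm_completion_request["messages"] = session_messages + current_messages
    litellm_completion_request["litellm_trace_id"] = chat_completion_session.get(
        "litellm_session_id"
    )
    print(litellm_completion_request["messages"])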
@@ -471,11 +502,13 @@ class LiteLLMCompletionResponsesConfig:
    def transform_chat_completion_response_to_responses_api_response(
        request_input: Union[str, ResponseInputParam],
        responses_api_request: ResponsesAPIOptionalRequestParams,
        chat_completion_response: ModelResponse,
        chat_completion_response: Union[ModelResponse, dict],
    ) -> ResponsesAPIResponse:
        """
        Transform a Chat Completion response into a Responses API response
        """
        if isinstance(chat_completion_response, dict):
            chat_completion_response = ModelResponse(**chat_completion_response)
        responses_api_response: ResponsesAPIResponse = ResponsesAPIResponse(
            id=chat_completion_response.id,
            created_at=chat_completion_response.created,
@@ -513,16 +546,6 @@ class LiteLLMCompletionResponsesConfig:
            ),
            user=getattr(chat_completion_response, "user", None),
        )

        RESPONSES_API_SESSION_HANDLER.add_completed_response_to_cache(
            response_id=responses_api_response.id,
            session_element=ResponsesAPISessionElement(
                input=request_input,
                output=responses_api_response,
                response_id=responses_api_response.id,
                previous_response_id=responses_api_request.get("previous_response_id"),
            ),
        )
        return responses_api_response

    @staticmethod
@@ -434,3 +434,188 @@ def delete_responses(
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )


@client
async def aget_responses(
    response_id: str,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> ResponsesAPIResponse:
    """
    Async: Fetch a response by its ID.

    GET /v1/responses/{response_id} endpoint in the responses API

    Args:
        response_id: The ID of the response to fetch.
        custom_llm_provider: Optional provider name. If not specified, will be decoded from response_id.

    Returns:
        The response object with complete information about the stored response.
    """
    local_vars = locals()
    try:
        loop = asyncio.get_event_loop()
        kwargs["aget_responses"] = True

        # get custom llm provider from response_id
        decoded_response_id: DecodedResponseId = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(
                response_id=response_id,
            )
        )
        response_id = decoded_response_id.get("response_id") or response_id
        custom_llm_provider = (
            decoded_response_id.get("custom_llm_provider") or custom_llm_provider
        )

        func = partial(
            get_responses,
            response_id=response_id,
            custom_llm_provider=custom_llm_provider,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            **kwargs,
        )

        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response

        # Update the responses_api_response_id with the model_id
        if isinstance(response, ResponsesAPIResponse):
            response = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
                responses_api_response=response,
                litellm_metadata=kwargs.get("litellm_metadata", {}),
                custom_llm_provider=custom_llm_provider,
            )
        return response
    except Exception as e:
        raise litellm.exception_type(
            model=None,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )

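A usage sketch of the new async getter, assuming it is exported at the package level as `litellm.aget_responses` like the other Responses API helpers; the id below is a placeholder standing in for one returned by an earlier `litellm.aresponses` call:

    # Minimal sketch; assumes litellm.aget_responses is the package-level
    # export of the function above and provider credentials are configured.
    import asyncio

    import litellm


    async def fetch_stored_response(response_id: str):
        # The provider is decoded from the id when possible, so
        # custom_llm_provider can usually be omitted.
        response = await litellm.aget_responses(response_id=response_id)
        print(response.id)
        return response


    # "resp_..." is a placeholder for an id returned by litellm.aresponses(...).
    asyncio.run(fetch_stored_response("resp_..."))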
@client
def get_responses(
    response_id: str,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[ResponsesAPIResponse, Coroutine[Any, Any, ResponsesAPIResponse]]:
    """
    Fetch a response by its ID.

    GET /v1/responses/{response_id} endpoint in the responses API

    Args:
        response_id: The ID of the response to fetch.
        custom_llm_provider: Optional provider name. If not specified, will be decoded from response_id.

    Returns:
        The response object with complete information about the stored response.
    """
    local_vars = locals()
    try:
        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
        litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
        _is_async = kwargs.pop("aget_responses", False) is True

        # get llm provider logic
        litellm_params = GenericLiteLLMParams(**kwargs)

        # get custom llm provider from response_id
        decoded_response_id: DecodedResponseId = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(
                response_id=response_id,
            )
        )
        response_id = decoded_response_id.get("response_id") or response_id
        custom_llm_provider = (
            decoded_response_id.get("custom_llm_provider") or custom_llm_provider
        )

        if custom_llm_provider is None:
            raise ValueError("custom_llm_provider is required but passed as None")

        # get provider config
        responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
            ProviderConfigManager.get_provider_responses_api_config(
                model=None,
                provider=litellm.LlmProviders(custom_llm_provider),
            )
        )

        if responses_api_provider_config is None:
            raise ValueError(
                f"GET responses is not supported for {custom_llm_provider}"
            )

        local_vars.update(kwargs)

        # Pre Call logging
        litellm_logging_obj.update_environment_variables(
            model=None,
            optional_params={
                "response_id": response_id,
            },
            litellm_params={
                "litellm_call_id": litellm_call_id,
            },
            custom_llm_provider=custom_llm_provider,
        )

        # Call the handler with _is_async flag instead of directly calling the async handler
        response = base_llm_http_handler.get_responses(
            response_id=response_id,
            custom_llm_provider=custom_llm_provider,
            responses_api_provider_config=responses_api_provider_config,
            litellm_params=litellm_params,
            logging_obj=litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=_is_async,
            client=kwargs.get("client"),
        )

        # Update the responses_api_response_id with the model_id
        if isinstance(response, ResponsesAPIResponse):
            response = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
                responses_api_response=response,
                litellm_metadata=kwargs.get("litellm_metadata", {}),
                custom_llm_provider=custom_llm_provider,
            )

        return response
    except Exception as e:
        raise litellm.exception_type(
            model=None,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )

@@ -176,6 +176,16 @@ class ResponsesAPIRequestUtils:
            response_id=response_id,
        )

    @staticmethod
    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
        """Get the model_id from the response_id"""
        if response_id is None:
            return None
        decoded_response_id = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
        )
        return decoded_response_id.get("model_id") or None


class ResponseAPILoggingUtils:
    @staticmethod
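A small illustration of the new helper, assuming ResponsesAPIRequestUtils lives in litellm.responses.utils and that the decoder falls back gracefully for ids LiteLLM did not encode (as the `or None` fallback suggests); the id string is a placeholder:

    # Illustrative only; "resp_abc123" is a placeholder, not a LiteLLM-encoded
    # id, so no model_id can be decoded from it and the helper returns None.
    from litellm.responses.utils import ResponsesAPIRequestUtils

    print(ResponsesAPIRequestUtils.get_model_id_from_response_id("resp_abc123"))  # None
    print(ResponsesAPIRequestUtils.get_model_id_from_response_id(None))  # None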