new mcp servers format

Davidson Gomes
2025-04-28 12:37:58 -03:00
parent 0112573d9b
commit e98744b7a4
7182 changed files with 4839 additions and 4998 deletions

View File

@@ -89,6 +89,16 @@ class LiteLLMCompletionTransformationHandler:
        responses_api_request: ResponsesAPIOptionalRequestParams,
        **kwargs,
    ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
        previous_response_id: Optional[str] = responses_api_request.get(
            "previous_response_id"
        )
        if previous_response_id:
            litellm_completion_request = await LiteLLMCompletionResponsesConfig.async_responses_api_session_handler(
                previous_response_id=previous_response_id,
                litellm_completion_request=litellm_completion_request,
            )

        litellm_completion_response: Union[
            ModelResponse, litellm.CustomStreamWrapper
        ] = await litellm.acompletion(

View File

@@ -1,59 +0,0 @@
"""
Responses API has previous_response_id, which is the id of the previous response.
LiteLLM needs to maintain a cache of the previous response input, output, previous_response_id, and model.
This class handles that cache.
"""
from typing import List, Optional, Tuple, Union
from typing_extensions import TypedDict
from litellm.caching import InMemoryCache
from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse
RESPONSES_API_PREVIOUS_RESPONSES_CACHE = InMemoryCache()
MAX_PREV_SESSION_INPUTS = 50
class ResponsesAPISessionElement(TypedDict, total=False):
input: Union[str, ResponseInputParam]
output: ResponsesAPIResponse
response_id: str
previous_response_id: Optional[str]
class SessionHandler:
def add_completed_response_to_cache(
self, response_id: str, session_element: ResponsesAPISessionElement
):
RESPONSES_API_PREVIOUS_RESPONSES_CACHE.set_cache(
key=response_id, value=session_element
)
def get_chain_of_previous_input_output_pairs(
self, previous_response_id: str
) -> List[Tuple[ResponseInputParam, ResponsesAPIResponse]]:
response_api_inputs: List[Tuple[ResponseInputParam, ResponsesAPIResponse]] = []
current_previous_response_id = previous_response_id
count_session_elements = 0
while current_previous_response_id:
if count_session_elements > MAX_PREV_SESSION_INPUTS:
break
session_element = RESPONSES_API_PREVIOUS_RESPONSES_CACHE.get_cache(
key=current_previous_response_id
)
if session_element:
response_api_inputs.append(
(session_element.get("input"), session_element.get("output"))
)
current_previous_response_id = session_element.get(
"previous_response_id"
)
else:
break
count_session_elements += 1
return response_api_inputs
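
For reference, a minimal usage sketch of the in-memory handler removed above (the response ids and input are hypothetical placeholders; output is omitted since the TypedDict is declared with total=False):

handler = SessionHandler()
handler.add_completed_response_to_cache(
    response_id="resp_2",
    session_element=ResponsesAPISessionElement(
        input="second user turn",       # hypothetical input
        response_id="resp_2",
        previous_response_id="resp_1",  # links this turn to the prior response
    ),
)
# Walking back from the newest id yields (input, output) pairs until the chain ends
# or MAX_PREV_SESSION_INPUTS elements have been visited.
pairs = handler.get_chain_of_previous_input_output_pairs(previous_response_id="resp_2")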

View File

@@ -5,13 +5,21 @@ Handles transforming from Responses API -> LiteLLM completion (Chat Completion
from typing import Any, Dict, List, Optional, Union

from openai.types.responses.tool_param import FunctionToolParam
from typing_extensions import TypedDict

HAS_ENTERPRISE_DIRECTORY = False
try:
    from enterprise.enterprise_hooks.session_handler import (
        _ENTERPRISE_ResponsesSessionHandler,
    )

    HAS_ENTERPRISE_DIRECTORY = True
except ImportError:
    _ENTERPRISE_ResponsesSessionHandler = None  # type: ignore
    HAS_ENTERPRISE_DIRECTORY = False

from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.responses.litellm_completion_transformation.session_handler import (
    ResponsesAPISessionElement,
    SessionHandler,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionResponseMessage,
@@ -48,7 +56,21 @@ from litellm.types.utils import (
########### Initialize Classes used for Responses API ###########
TOOL_CALLS_CACHE = InMemoryCache()
RESPONSES_API_SESSION_HANDLER = SessionHandler()


class ChatCompletionSession(TypedDict, total=False):
    messages: List[
        Union[
            AllMessageValues,
            GenericChatCompletionMessage,
            ChatCompletionMessageToolCall,
            ChatCompletionResponseMessage,
            Message,
        ]
    ]
    litellm_session_id: Optional[str]


########### End of Initialize Classes used for Responses API ###########
@@ -90,7 +112,6 @@ class LiteLLMCompletionResponsesConfig:
"messages": LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
input=input,
responses_api_request=responses_api_request,
previous_response_id=responses_api_request.get("previous_response_id"),
),
"model": model,
"tool_choice": responses_api_request.get("tool_choice"),
@@ -131,14 +152,14 @@ class LiteLLMCompletionResponsesConfig:
    @staticmethod
    def transform_responses_api_input_to_messages(
        input: Union[str, ResponseInputParam],
        responses_api_request: ResponsesAPIOptionalRequestParams,
        previous_response_id: Optional[str] = None,
        responses_api_request: Union[ResponsesAPIOptionalRequestParams, dict],
    ) -> List[
        Union[
            AllMessageValues,
            GenericChatCompletionMessage,
            ChatCompletionMessageToolCall,
            ChatCompletionResponseMessage,
            Message,
        ]
    ]:
        """
@@ -150,6 +171,7 @@ class LiteLLMCompletionResponsesConfig:
                GenericChatCompletionMessage,
                ChatCompletionMessageToolCall,
                ChatCompletionResponseMessage,
                Message,
            ]
        ] = []
        if responses_api_request.get("instructions"):
@@ -159,24 +181,6 @@ class LiteLLMCompletionResponsesConfig:
                )
            )

        if previous_response_id:
            previous_response_pairs = (
                RESPONSES_API_SESSION_HANDLER.get_chain_of_previous_input_output_pairs(
                    previous_response_id=previous_response_id
                )
            )
            if previous_response_pairs:
                for previous_response_pair in previous_response_pairs:
                    chat_completion_input_messages = LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
                        input=previous_response_pair[0],
                    )
                    chat_completion_output_messages = LiteLLMCompletionResponsesConfig._transform_responses_api_outputs_to_chat_completion_messages(
                        responses_api_output=previous_response_pair[1],
                    )
                    messages.extend(chat_completion_input_messages)
                    messages.extend(chat_completion_output_messages)

        messages.extend(
            LiteLLMCompletionResponsesConfig._transform_response_input_param_to_chat_completion_message(
                input=input,
@@ -185,6 +189,33 @@ class LiteLLMCompletionResponsesConfig:
        return messages

    @staticmethod
    async def async_responses_api_session_handler(
        previous_response_id: str,
        litellm_completion_request: dict,
    ) -> dict:
        """
        Async hook to get the chain of previous input and output pairs and return a list of Chat Completion messages
        """
        if (
            HAS_ENTERPRISE_DIRECTORY is True
            and _ENTERPRISE_ResponsesSessionHandler is not None
        ):
            chat_completion_session = ChatCompletionSession(
                messages=[], litellm_session_id=None
            )
            if previous_response_id:
                chat_completion_session = await _ENTERPRISE_ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
                    previous_response_id=previous_response_id
                )
            _messages = litellm_completion_request.get("messages") or []
            session_messages = chat_completion_session.get("messages") or []
            litellm_completion_request["messages"] = session_messages + _messages
            litellm_completion_request["litellm_trace_id"] = (
                chat_completion_session.get("litellm_session_id")
            )
        return litellm_completion_request
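
    # Illustrative sketch (not part of this change) of what the hook above does to a completion
    # request; the response id and message content are hypothetical placeholders.
    #
    #   request = {"model": "gpt-4o", "messages": [{"role": "user", "content": "and then?"}]}
    #   request = await LiteLLMCompletionResponsesConfig.async_responses_api_session_handler(
    #       previous_response_id="resp_abc123",
    #       litellm_completion_request=request,
    #   )
    #   # With the enterprise session handler available, request["messages"] now has the stored
    #   # history prepended and request["litellm_trace_id"] carries the session id; otherwise
    #   # the request is returned unchanged.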
    @staticmethod
    def _transform_response_input_param_to_chat_completion_message(
        input: Union[str, ResponseInputParam],
@@ -471,11 +502,13 @@ class LiteLLMCompletionResponsesConfig:
    def transform_chat_completion_response_to_responses_api_response(
        request_input: Union[str, ResponseInputParam],
        responses_api_request: ResponsesAPIOptionalRequestParams,
        chat_completion_response: ModelResponse,
        chat_completion_response: Union[ModelResponse, dict],
    ) -> ResponsesAPIResponse:
        """
        Transform a Chat Completion response into a Responses API response
        """
        if isinstance(chat_completion_response, dict):
            chat_completion_response = ModelResponse(**chat_completion_response)

        responses_api_response: ResponsesAPIResponse = ResponsesAPIResponse(
            id=chat_completion_response.id,
            created_at=chat_completion_response.created,
@@ -513,16 +546,6 @@ class LiteLLMCompletionResponsesConfig:
            ),
            user=getattr(chat_completion_response, "user", None),
        )
        RESPONSES_API_SESSION_HANDLER.add_completed_response_to_cache(
            response_id=responses_api_response.id,
            session_element=ResponsesAPISessionElement(
                input=request_input,
                output=responses_api_response,
                response_id=responses_api_response.id,
                previous_response_id=responses_api_request.get("previous_response_id"),
            ),
        )
        return responses_api_response
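
    # Illustrative sketch (not part of this change): the transform above now also accepts a plain
    # dict (e.g. a serialized chat completion) and coerces it to ModelResponse first. Field values
    # below are hypothetical placeholders.
    #
    #   raw = {"id": "chatcmpl-123", "created": 1714000000, "model": "gpt-4o", "choices": []}
    #   responses_api_response = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
    #       request_input="hello",
    #       responses_api_request={},
    #       chat_completion_response=raw,
    #   )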
    @staticmethod

View File

@@ -434,3 +434,188 @@ def delete_responses(
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )


@client
async def aget_responses(
    response_id: str,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> ResponsesAPIResponse:
    """
    Async: Fetch a response by its ID.

    GET /v1/responses/{response_id} endpoint in the responses API

    Args:
        response_id: The ID of the response to fetch.
        custom_llm_provider: Optional provider name. If not specified, will be decoded from response_id.

    Returns:
        The response object with complete information about the stored response.
    """
    local_vars = locals()
    try:
        loop = asyncio.get_event_loop()
        kwargs["aget_responses"] = True

        # get custom llm provider from response_id
        decoded_response_id: DecodedResponseId = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(
                response_id=response_id,
            )
        )
        response_id = decoded_response_id.get("response_id") or response_id
        custom_llm_provider = (
            decoded_response_id.get("custom_llm_provider") or custom_llm_provider
        )

        func = partial(
            get_responses,
            response_id=response_id,
            custom_llm_provider=custom_llm_provider,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            **kwargs,
        )

        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response

        # Update the responses_api_response_id with the model_id
        if isinstance(response, ResponsesAPIResponse):
            response = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
                responses_api_response=response,
                litellm_metadata=kwargs.get("litellm_metadata", {}),
                custom_llm_provider=custom_llm_provider,
            )
        return response
    except Exception as e:
        raise litellm.exception_type(
            model=None,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
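
# Illustrative usage sketch (not part of this change), assuming aget_responses is re-exported at the
# package level like the other responses helpers. The response id is a hypothetical placeholder, so
# the provider is passed explicitly rather than decoded from the id.
#
#   import litellm
#
#   response = await litellm.aget_responses(
#       response_id="resp_abc123",
#       custom_llm_provider="openai",
#   )
#   print(response.id)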


@client
def get_responses(
    response_id: str,
    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
    # The extra values given here take precedence over values defined on the client or passed to this method.
    extra_headers: Optional[Dict[str, Any]] = None,
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[ResponsesAPIResponse, Coroutine[Any, Any, ResponsesAPIResponse]]:
    """
    Fetch a response by its ID.

    GET /v1/responses/{response_id} endpoint in the responses API

    Args:
        response_id: The ID of the response to fetch.
        custom_llm_provider: Optional provider name. If not specified, will be decoded from response_id.

    Returns:
        The response object with complete information about the stored response.
    """
    local_vars = locals()
    try:
        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
        litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
        _is_async = kwargs.pop("aget_responses", False) is True

        # get llm provider logic
        litellm_params = GenericLiteLLMParams(**kwargs)

        # get custom llm provider from response_id
        decoded_response_id: DecodedResponseId = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(
                response_id=response_id,
            )
        )
        response_id = decoded_response_id.get("response_id") or response_id
        custom_llm_provider = (
            decoded_response_id.get("custom_llm_provider") or custom_llm_provider
        )

        if custom_llm_provider is None:
            raise ValueError("custom_llm_provider is required but passed as None")

        # get provider config
        responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
            ProviderConfigManager.get_provider_responses_api_config(
                model=None,
                provider=litellm.LlmProviders(custom_llm_provider),
            )
        )

        if responses_api_provider_config is None:
            raise ValueError(
                f"GET responses is not supported for {custom_llm_provider}"
            )

        local_vars.update(kwargs)

        # Pre Call logging
        litellm_logging_obj.update_environment_variables(
            model=None,
            optional_params={
                "response_id": response_id,
            },
            litellm_params={
                "litellm_call_id": litellm_call_id,
            },
            custom_llm_provider=custom_llm_provider,
        )

        # Call the handler with _is_async flag instead of directly calling the async handler
        response = base_llm_http_handler.get_responses(
            response_id=response_id,
            custom_llm_provider=custom_llm_provider,
            responses_api_provider_config=responses_api_provider_config,
            litellm_params=litellm_params,
            logging_obj=litellm_logging_obj,
            extra_headers=extra_headers,
            extra_body=extra_body,
            timeout=timeout or request_timeout,
            _is_async=_is_async,
            client=kwargs.get("client"),
        )

        # Update the responses_api_response_id with the model_id
        if isinstance(response, ResponsesAPIResponse):
            response = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
                responses_api_response=response,
                litellm_metadata=kwargs.get("litellm_metadata", {}),
                custom_llm_provider=custom_llm_provider,
            )
        return response
    except Exception as e:
        raise litellm.exception_type(
            model=None,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs=local_vars,
            extra_kwargs=kwargs,
        )
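
A minimal synchronous usage sketch (the response id below is a hypothetical placeholder, so the provider is passed explicitly instead of being decoded from the id; this also assumes get_responses is re-exported at the package level like the other responses helpers):

import litellm

response = litellm.get_responses(
    response_id="resp_abc123",
    custom_llm_provider="openai",
)
print(response.id)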

View File

@@ -176,6 +176,16 @@ class ResponsesAPIRequestUtils:
            response_id=response_id,
        )

    @staticmethod
    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
        """Get the model_id from the response_id"""
        if response_id is None:
            return None
        decoded_response_id = (
            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
        )
        return decoded_response_id.get("model_id") or None
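
    # Illustrative sketch (not part of this change): None, or an id with no encoded model_id,
    # resolves to None. The id below is a hypothetical placeholder.
    #
    #   ResponsesAPIRequestUtils.get_model_id_from_response_id(None)           # -> None
    #   ResponsesAPIRequestUtils.get_model_id_from_response_id("resp_abc123")  # -> None when no model_id is embedded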


class ResponseAPILoggingUtils:
    @staticmethod