structure saas with tools
This commit is contained in:
@@ -0,0 +1,336 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.constants import STREAM_SSE_DONE_STRING
|
||||
from litellm.litellm_core_utils.asyncify import run_async_function
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.thread_pool_executor import executor
|
||||
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
||||
from litellm.responses.utils import ResponsesAPIRequestUtils
|
||||
from litellm.types.llms.openai import (
|
||||
OutputTextDeltaEvent,
|
||||
ResponseCompletedEvent,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamEvents,
|
||||
ResponsesAPIStreamingResponse,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
|
||||
class BaseResponsesAPIStreamingIterator:
|
||||
"""
|
||||
Base class for streaming iterators that process responses from the Responses API.
|
||||
|
||||
This class contains shared logic for both synchronous and asynchronous iterators.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
model: str,
|
||||
responses_api_provider_config: BaseResponsesAPIConfig,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_metadata: Optional[Dict[str, Any]] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
):
|
||||
self.response = response
|
||||
self.model = model
|
||||
self.logging_obj = logging_obj
|
||||
self.finished = False
|
||||
self.responses_api_provider_config = responses_api_provider_config
|
||||
self.completed_response: Optional[ResponsesAPIStreamingResponse] = None
|
||||
self.start_time = datetime.now()
|
||||
|
||||
# set request kwargs
|
||||
self.litellm_metadata = litellm_metadata
|
||||
self.custom_llm_provider = custom_llm_provider
|
||||
|
||||
def _process_chunk(self, chunk) -> Optional[ResponsesAPIStreamingResponse]:
|
||||
"""Process a single chunk of data from the stream"""
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
# Handle SSE format (data: {...})
|
||||
chunk = CustomStreamWrapper._strip_sse_data_from_chunk(chunk)
|
||||
if chunk is None:
|
||||
return None
|
||||
|
||||
# Handle "[DONE]" marker
|
||||
if chunk == STREAM_SSE_DONE_STRING:
|
||||
self.finished = True
|
||||
return None
|
||||
|
||||
try:
|
||||
# Parse the JSON chunk
|
||||
parsed_chunk = json.loads(chunk)
|
||||
|
||||
# Format as ResponsesAPIStreamingResponse
|
||||
if isinstance(parsed_chunk, dict):
|
||||
openai_responses_api_chunk = (
|
||||
self.responses_api_provider_config.transform_streaming_response(
|
||||
model=self.model,
|
||||
parsed_chunk=parsed_chunk,
|
||||
logging_obj=self.logging_obj,
|
||||
)
|
||||
)
|
||||
|
||||
# if "response" in parsed_chunk, then encode litellm specific information like custom_llm_provider
|
||||
response_object = getattr(openai_responses_api_chunk, "response", None)
|
||||
if response_object:
|
||||
response = ResponsesAPIRequestUtils._update_responses_api_response_id_with_model_id(
|
||||
responses_api_response=response_object,
|
||||
litellm_metadata=self.litellm_metadata,
|
||||
custom_llm_provider=self.custom_llm_provider,
|
||||
)
|
||||
setattr(openai_responses_api_chunk, "response", response)
|
||||
|
||||
# Store the completed response
|
||||
if (
|
||||
openai_responses_api_chunk
|
||||
and openai_responses_api_chunk.type
|
||||
== ResponsesAPIStreamEvents.RESPONSE_COMPLETED
|
||||
):
|
||||
self.completed_response = openai_responses_api_chunk
|
||||
self._handle_logging_completed_response()
|
||||
|
||||
return openai_responses_api_chunk
|
||||
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
# If we can't parse the chunk, continue
|
||||
return None
|
||||
|
||||
def _handle_logging_completed_response(self):
|
||||
"""Base implementation - should be overridden by subclasses"""
|
||||
pass
|
||||
|
||||
|
||||
class ResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
|
||||
"""
|
||||
Async iterator for processing streaming responses from the Responses API.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
model: str,
|
||||
responses_api_provider_config: BaseResponsesAPIConfig,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_metadata: Optional[Dict[str, Any]] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
response,
|
||||
model,
|
||||
responses_api_provider_config,
|
||||
logging_obj,
|
||||
litellm_metadata,
|
||||
custom_llm_provider,
|
||||
)
|
||||
self.stream_iterator = response.aiter_lines()
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> ResponsesAPIStreamingResponse:
|
||||
try:
|
||||
while True:
|
||||
# Get the next chunk from the stream
|
||||
try:
|
||||
chunk = await self.stream_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
self.finished = True
|
||||
raise StopAsyncIteration
|
||||
|
||||
result = self._process_chunk(chunk)
|
||||
|
||||
if self.finished:
|
||||
raise StopAsyncIteration
|
||||
elif result is not None:
|
||||
return result
|
||||
# If result is None, continue the loop to get the next chunk
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
# Handle HTTP errors
|
||||
self.finished = True
|
||||
raise e
|
||||
|
||||
def _handle_logging_completed_response(self):
|
||||
"""Handle logging for completed responses in async context"""
|
||||
asyncio.create_task(
|
||||
self.logging_obj.async_success_handler(
|
||||
result=self.completed_response,
|
||||
start_time=self.start_time,
|
||||
end_time=datetime.now(),
|
||||
cache_hit=None,
|
||||
)
|
||||
)
|
||||
|
||||
executor.submit(
|
||||
self.logging_obj.success_handler,
|
||||
result=self.completed_response,
|
||||
cache_hit=None,
|
||||
start_time=self.start_time,
|
||||
end_time=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
|
||||
"""
|
||||
Synchronous iterator for processing streaming responses from the Responses API.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
model: str,
|
||||
responses_api_provider_config: BaseResponsesAPIConfig,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_metadata: Optional[Dict[str, Any]] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
response,
|
||||
model,
|
||||
responses_api_provider_config,
|
||||
logging_obj,
|
||||
litellm_metadata,
|
||||
custom_llm_provider,
|
||||
)
|
||||
self.stream_iterator = response.iter_lines()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
while True:
|
||||
# Get the next chunk from the stream
|
||||
try:
|
||||
chunk = next(self.stream_iterator)
|
||||
except StopIteration:
|
||||
self.finished = True
|
||||
raise StopIteration
|
||||
|
||||
result = self._process_chunk(chunk)
|
||||
|
||||
if self.finished:
|
||||
raise StopIteration
|
||||
elif result is not None:
|
||||
return result
|
||||
# If result is None, continue the loop to get the next chunk
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
# Handle HTTP errors
|
||||
self.finished = True
|
||||
raise e
|
||||
|
||||
def _handle_logging_completed_response(self):
|
||||
"""Handle logging for completed responses in sync context"""
|
||||
run_async_function(
|
||||
async_function=self.logging_obj.async_success_handler,
|
||||
result=self.completed_response,
|
||||
start_time=self.start_time,
|
||||
end_time=datetime.now(),
|
||||
cache_hit=None,
|
||||
)
|
||||
|
||||
executor.submit(
|
||||
self.logging_obj.success_handler,
|
||||
result=self.completed_response,
|
||||
cache_hit=None,
|
||||
start_time=self.start_time,
|
||||
end_time=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
|
||||
"""
|
||||
Mock iterator—fake a stream by slicing the full response text into
|
||||
5 char deltas, then emit a completed event.
|
||||
|
||||
Models like o1-pro don't support streaming, so we fake it.
|
||||
"""
|
||||
|
||||
CHUNK_SIZE = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
model: str,
|
||||
responses_api_provider_config: BaseResponsesAPIConfig,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_metadata: Optional[Dict[str, Any]] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
response=response,
|
||||
model=model,
|
||||
responses_api_provider_config=responses_api_provider_config,
|
||||
logging_obj=logging_obj,
|
||||
litellm_metadata=litellm_metadata,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
# one-time transform
|
||||
transformed = (
|
||||
self.responses_api_provider_config.transform_response_api_response(
|
||||
model=self.model,
|
||||
raw_response=response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
)
|
||||
full_text = self._collect_text(transformed)
|
||||
|
||||
# build a list of 5‑char delta events
|
||||
deltas = [
|
||||
OutputTextDeltaEvent(
|
||||
type=ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA,
|
||||
delta=full_text[i : i + self.CHUNK_SIZE],
|
||||
item_id=transformed.id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
)
|
||||
for i in range(0, len(full_text), self.CHUNK_SIZE)
|
||||
]
|
||||
|
||||
# append the completed event
|
||||
self._events = deltas + [
|
||||
ResponseCompletedEvent(
|
||||
type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
|
||||
response=transformed,
|
||||
)
|
||||
]
|
||||
self._idx = 0
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> ResponsesAPIStreamingResponse:
|
||||
if self._idx >= len(self._events):
|
||||
raise StopAsyncIteration
|
||||
evt = self._events[self._idx]
|
||||
self._idx += 1
|
||||
return evt
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self) -> ResponsesAPIStreamingResponse:
|
||||
if self._idx >= len(self._events):
|
||||
raise StopIteration
|
||||
evt = self._events[self._idx]
|
||||
self._idx += 1
|
||||
return evt
|
||||
|
||||
def _collect_text(self, resp: ResponsesAPIResponse) -> str:
|
||||
out = ""
|
||||
for out_item in resp.output:
|
||||
if out_item.type == "message":
|
||||
for c in getattr(out_item, "content", []):
|
||||
out += c.text
|
||||
return out
|
||||
Reference in New Issue
Block a user