structure saas with tools
@@ -0,0 +1,16 @@
from fastapi import Request


def get_litellm_virtual_key(request: Request) -> str:
    """
    Extract and format the API key from request headers.
    Prioritizes x-litellm-api-key over the Authorization header.

    The Vertex JS SDK uses the `Authorization` header, so we use `x-litellm-api-key` to pass the litellm virtual key.
    """
    litellm_api_key = request.headers.get("x-litellm-api-key")
    if litellm_api_key:
        return f"Bearer {litellm_api_key}"
    return request.headers.get("Authorization", "")
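Since the Vertex routes authenticate with this helper, a client can put the litellm virtual key in `x-litellm-api-key` and leave `Authorization` for whatever the provider SDK sets. A minimal client-side sketch; the proxy URL, project, and key below are hypothetical:

import httpx

# Hypothetical proxy base URL and litellm virtual key, for illustration only.
resp = httpx.post(
    "http://localhost:4000/vertex_ai/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent",
    headers={"x-litellm-api-key": "sk-litellm-..."},  # takes precedence over Authorization
    json={"contents": [{"role": "user", "parts": [{"text": "hi"}]}]},
)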
@@ -0,0 +1,965 @@
"""
What is this?

Provider-specific Pass-Through Endpoints

Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
"""

import json
import os
from typing import Literal, Optional

import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.proxy._types import *
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
    create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str

from .passthrough_endpoint_router import PassthroughEndpointRouter

vertex_llm_base = VertexBase()
router = APIRouter()
default_vertex_config = None

passthrough_endpoint_router = PassthroughEndpointRouter()


def create_request_copy(request: Request):
    return {
        "method": request.method,
        "url": str(request.url),
        "headers": dict(request.headers),
        "cookies": request.cookies,
        "query_params": dict(request.query_params),
    }


async def llm_passthrough_factory_proxy_route(
    custom_llm_provider: str,
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Factory function for creating pass-through endpoints for LLM providers.
    """
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    provider_config = ProviderConfigManager.get_provider_model_info(
        provider=LlmProviders(custom_llm_provider),
        model=None,
    )
    if provider_config is None:
        raise HTTPException(
            status_code=404, detail=f"Provider {custom_llm_provider} not found"
        )
    base_target_url = provider_config.get_api_base()

    if base_target_url is None:
        raise HTTPException(
            status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
        )

    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    provider_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=custom_llm_provider,
        region_name=None,
    )

    auth_headers = provider_config.validate_environment(
        headers={},
        model="",
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=provider_api_key,
        api_base=base_target_url,
    )

    ## check for streaming
    is_streaming_request = False
    # anthropic is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers=auth_headers,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


@router.api_route(
    "/gemini/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Google AI Studio Pass-through", "pass-through"],
)
async def gemini_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
    """
    ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?key=LITELLM_API_KEY
    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
        "x-goog-api-key"
    )

    user_api_key_dict = await user_api_key_auth(
        request=request, api_key=f"Bearer {google_ai_studio_api_key}"
    )

    base_target_url = "https://generativelanguage.googleapis.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="gemini",
        region_name=None,
    )
    if gemini_api_key is None:
        raise Exception(
            "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio."
        )
    # Merge query parameters, giving precedence to those in updated_url
    merged_params = dict(request.query_params)
    merged_params.update({"key": gemini_api_key})

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        query_params=merged_params,  # type: ignore
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


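Clients authenticate on this route by putting the litellm virtual key where the Google SDK normally puts the Google API key; the route swaps in the real `GEMINI_API_KEY` before forwarding. A minimal sketch, with a hypothetical proxy URL and key:

import httpx

# Hypothetical proxy base URL and litellm virtual key, for illustration only.
resp = httpx.post(
    "http://localhost:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent",
    params={"key": "sk-litellm-..."},  # read by gemini_proxy_route for auth
    json={"contents": [{"parts": [{"text": "hello"}]}]},
)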
@router.api_route(
    "/cohere/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Cohere Pass-through", "pass-through"],
)
async def cohere_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/cohere)
    """
    base_target_url = "https://api.cohere.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    cohere_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="cohere",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


@router.api_route(
    "/vllm/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["VLLM Pass-through", "pass-through"],
)
async def vllm_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/vllm)
    """
    return await llm_passthrough_factory_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        custom_llm_provider="vllm",
    )


@router.api_route(
    "/mistral/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Mistral Pass-through", "pass-through"],
)
async def mistral_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
    """
    base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    mistral_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="mistral",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    # anthropic is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


@router.api_route(
    "/anthropic/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Anthropic Pass-through", "pass-through"],
)
async def anthropic_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/anthropic_completion)
    """
    base_target_url = "https://api.anthropic.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    anthropic_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="anthropic",
        region_name=None,
    )

    ## check for streaming
    is_streaming_request = False
    # anthropic is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"x-api-key": "{}".format(anthropic_api_key)},
        _forward_headers=True,
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


@router.api_route(
    "/bedrock/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Bedrock Pass-through", "pass-through"],
)
async def bedrock_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://docs.litellm.ai/docs/pass_through/bedrock)
    """
    create_request_copy(request)

    try:
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
        from botocore.credentials import Credentials
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

    aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
    if _is_bedrock_agent_runtime_route(endpoint=endpoint):  # handle bedrock agents
        base_target_url = (
            f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
        )
    else:
        base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
    encoded_endpoint = httpx.URL(endpoint).path

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    from litellm.llms.bedrock.chat import BedrockConverseLLM

    credentials: Credentials = BedrockConverseLLM().get_credentials()
    sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
    headers = {"Content-Type": "application/json"}
    # Assuming the body contains JSON data, parse it
    try:
        data = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail={"error": str(e)})
    _request = AWSRequest(
        method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
    )
    sigv4.add_auth(_request)
    prepped = _request.prepare()

    ## check for streaming
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(prepped.url),
        custom_headers=prepped.headers,  # type: ignore
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request,
        fastapi_response,
        user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
        custom_body=data,  # type: ignore
        query_params={},  # type: ignore
    )

    return received_value


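Because Bedrock uses AWS SigV4 request signing rather than bearer tokens, the route above parses the body, signs it itself, and forwards the signed headers; the client only needs the litellm virtual key. A minimal sketch of calling it through the proxy; the proxy URL, model id, and key are hypothetical:

import httpx

# Hypothetical proxy base URL, litellm virtual key, and Bedrock model id.
resp = httpx.post(
    "http://localhost:4000/bedrock/model/anthropic.claude-3-5-sonnet-20240620-v1:0/converse",
    headers={"Authorization": "Bearer sk-litellm-..."},
    json={"messages": [{"role": "user", "content": [{"text": "hi"}]}]},
)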
def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
    """
    Return True if the endpoint should be routed to the `bedrock-agent-runtime` endpoint.
    """
    for _route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES:
        if _route in endpoint:
            return True
    return False


@router.api_route(
    "/assemblyai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["AssemblyAI Pass-through", "pass-through"],
)
@router.api_route(
    "/eu.assemblyai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["AssemblyAI EU Pass-through", "pass-through"],
)
async def assemblyai_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Docs](https://api.assemblyai.com)
    """
    from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
        AssemblyAIPassthroughLoggingHandler,
    )

    # Set base URL based on the route
    assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
        url=str(request.url)
    )
    base_target_url = (
        AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
            region=assembly_region
        )
    )
    encoded_endpoint = httpx.URL(endpoint).path
    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)

    # Add or update query parameters
    assemblyai_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider="assemblyai",
        region_name=assembly_region,
    )

    ## check for streaming
    is_streaming_request = False
    # assemblyai is streaming when 'stream' = True is in the body
    if request.method == "POST":
        _request_body = await request.json()
        if _request_body.get("stream"):
            is_streaming_request = True

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=str(updated_url),
        custom_headers={"Authorization": "{}".format(assemblyai_api_key)},
    )  # dynamically construct pass-through endpoint based on incoming path
    received_value = await endpoint_func(
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        stream=is_streaming_request,  # type: ignore
    )

    return received_value


@router.api_route(
    "/azure/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Azure Pass-through", "pass-through"],
)
async def azure_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Call any azure endpoint using the proxy.

    Just use `{PROXY_BASE_URL}/azure/{endpoint:path}`
    """
    base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
    if base_target_url is None:
        raise Exception(
            "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
        )
    # Add or update query parameters
    azure_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=litellm.LlmProviders.AZURE.value,
        region_name=None,
    )
    if azure_api_key is None:
        raise Exception(
            "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
        )

    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        base_target_url=base_target_url,
        api_key=azure_api_key,
        custom_llm_provider=litellm.LlmProviders.AZURE,
    )


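A sketch of the environment this route expects; the values are hypothetical. `AZURE_API_BASE` supplies the forwarding target, and the Azure credential is resolved by the passthrough router (typically from `AZURE_API_KEY`):

import os

# Hypothetical values, for illustration only.
os.environ["AZURE_API_BASE"] = "https://my-resource.openai.azure.com"
os.environ["AZURE_API_KEY"] = "azure-key-..."
# With these set, requests to {PROXY_BASE_URL}/azure/{endpoint:path} are
# forwarded to the Azure resource with the api-key attached.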
from abc import ABC, abstractmethod


class BaseVertexAIPassThroughHandler(ABC):
    @staticmethod
    @abstractmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        pass

    @staticmethod
    @abstractmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        pass


class VertexAIDiscoveryPassThroughHandler(BaseVertexAIPassThroughHandler):
    @staticmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        return "https://discoveryengine.googleapis.com/"

    @staticmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        return base_target_url


class VertexAIPassThroughHandler(BaseVertexAIPassThroughHandler):
    @staticmethod
    def get_default_base_target_url(vertex_location: Optional[str]) -> str:
        return f"https://{vertex_location}-aiplatform.googleapis.com/"

    @staticmethod
    def update_base_target_url_with_credential_location(
        base_target_url: str, vertex_location: Optional[str]
    ) -> str:
        return f"https://{vertex_location}-aiplatform.googleapis.com/"


def get_vertex_pass_through_handler(
    call_type: Literal["discovery", "aiplatform"]
) -> BaseVertexAIPassThroughHandler:
    if call_type == "discovery":
        return VertexAIDiscoveryPassThroughHandler()
    elif call_type == "aiplatform":
        return VertexAIPassThroughHandler()
    else:
        raise ValueError(f"Invalid call type: {call_type}")


async def _base_vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    get_vertex_pass_through_handler: BaseVertexAIPassThroughHandler,
    user_api_key_dict: Optional[UserAPIKeyAuth] = None,
):
    """
    Base function for Vertex AI passthrough routes.
    Handles common logic for all Vertex AI services.

    Default base_target_url is `https://{vertex_location}-aiplatform.googleapis.com/`
    """
    from litellm.llms.vertex_ai.common_utils import (
        construct_target_url,
        get_vertex_location_from_url,
        get_vertex_project_id_from_url,
    )

    encoded_endpoint = httpx.URL(endpoint).path
    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}
    api_key_to_use = get_litellm_virtual_key(request=request)
    user_api_key_dict = await user_api_key_auth(
        request=request,
        api_key=api_key_to_use,
    )

    if user_api_key_dict is None:
        api_key_to_use = get_litellm_virtual_key(request=request)
        user_api_key_dict = await user_api_key_auth(
            request=request,
            api_key=api_key_to_use,
        )

    vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
    vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
    vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
        project_id=vertex_project,
        location=vertex_location,
    )

    base_target_url = get_vertex_pass_through_handler.get_default_base_target_url(
        vertex_location
    )

    headers_passed_through = False
    # Use headers from the incoming request if no vertex credentials are found
    if vertex_credentials is None or vertex_credentials.vertex_project is None:
        headers = dict(request.headers) or {}
        headers_passed_through = True
        verbose_proxy_logger.debug(
            "default_vertex_config not set, incoming request headers %s", headers
        )
        headers.pop("content-length", None)
        headers.pop("host", None)
    else:
        vertex_project = vertex_credentials.vertex_project
        vertex_location = vertex_credentials.vertex_location
        vertex_credentials_str = vertex_credentials.vertex_credentials

        _auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
            credentials=vertex_credentials_str,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai_beta",
        )

        auth_header, _ = vertex_llm_base._get_token_and_url(
            model="",
            auth_header=_auth_header,
            gemini_api_key=None,
            vertex_credentials=vertex_credentials_str,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            stream=False,
            custom_llm_provider="vertex_ai_beta",
            api_base="",
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
        }

        base_target_url = get_vertex_pass_through_handler.update_base_target_url_with_credential_location(
            base_target_url, vertex_location
        )

        if base_target_url is None:
            base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"

    request_route = encoded_endpoint
    verbose_proxy_logger.debug("request_route %s", request_route)

    # Ensure endpoint starts with '/' for proper URL construction
    if not encoded_endpoint.startswith("/"):
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        vertex_location=vertex_location,
        vertex_project=vertex_project,
    )

    verbose_proxy_logger.debug("updated url %s", updated_url)

    ## check for streaming
    target = str(updated_url)
    is_streaming_request = False
    if "stream" in str(updated_url):
        is_streaming_request = True
        target += "?alt=sse"

    ## CREATE PASS-THROUGH
    endpoint_func = create_pass_through_route(
        endpoint=endpoint,
        target=target,
        custom_headers=headers,
    )  # dynamically construct pass-through endpoint based on incoming path

    try:
        received_value = await endpoint_func(
            request,
            fastapi_response,
            user_api_key_dict,
            stream=is_streaming_request,  # type: ignore
        )
    except ProxyException as e:
        if headers_passed_through:
            e.message = f"No credentials found on proxy for project_name={vertex_project} + location={vertex_location}, check `/model/info` for allowed project + region combinations with `use_in_pass_through: true`. Headers were passed through directly but request failed with error: {e.message}"
        raise e

    return received_value


@router.api_route(
    "/vertex_ai/discovery/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_discovery_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
):
    """
    Call any vertex discovery endpoint using the proxy.

    Just use `{PROXY_BASE_URL}/vertex_ai/discovery/{endpoint:path}`

    Target url: `https://discoveryengine.googleapis.com`
    """

    discovery_handler = get_vertex_pass_through_handler(call_type="discovery")
    return await _base_vertex_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        get_vertex_pass_through_handler=discovery_handler,
    )


@router.api_route(
    "/vertex-ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
    include_in_schema=False,
)
@router.api_route(
    "/vertex_ai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Call LiteLLM proxy via Vertex AI SDK.

    [Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
    """
    ai_platform_handler = get_vertex_pass_through_handler(call_type="aiplatform")

    return await _base_vertex_proxy_route(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        get_vertex_pass_through_handler=ai_platform_handler,
        user_api_key_dict=user_api_key_dict,
    )


@router.api_route(
    "/openai/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["OpenAI Pass-through", "pass-through"],
)
async def openai_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Simple pass-through for OpenAI. Use this if you want to directly send a request to OpenAI.
    """
    base_target_url = "https://api.openai.com/"
    # Add or update query parameters
    openai_api_key = passthrough_endpoint_router.get_credentials(
        custom_llm_provider=litellm.LlmProviders.OPENAI.value,
        region_name=None,
    )
    if openai_api_key is None:
        raise Exception(
            "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
        )

    return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
        endpoint=endpoint,
        request=request,
        fastapi_response=fastapi_response,
        user_api_key_dict=user_api_key_dict,
        base_target_url=base_target_url,
        api_key=openai_api_key,
        custom_llm_provider=litellm.LlmProviders.OPENAI,
    )


class BaseOpenAIPassThroughHandler:
    @staticmethod
    async def _base_openai_pass_through_handler(
        endpoint: str,
        request: Request,
        fastapi_response: Response,
        user_api_key_dict: UserAPIKeyAuth,
        base_target_url: str,
        api_key: str,
        custom_llm_provider: litellm.LlmProviders,
    ):
        encoded_endpoint = httpx.URL(endpoint).path
        # Ensure endpoint starts with '/' for proper URL construction
        if not encoded_endpoint.startswith("/"):
            encoded_endpoint = "/" + encoded_endpoint

        # Construct the full target URL by properly joining the base URL and endpoint path
        base_url = httpx.URL(base_target_url)
        updated_url = BaseOpenAIPassThroughHandler._join_url_paths(
            base_url=base_url,
            path=encoded_endpoint,
            custom_llm_provider=custom_llm_provider,
        )

        ## check for streaming
        is_streaming_request = False
        if "stream" in str(updated_url):
            is_streaming_request = True

        ## CREATE PASS-THROUGH
        endpoint_func = create_pass_through_route(
            endpoint=endpoint,
            target=str(updated_url),
            custom_headers=BaseOpenAIPassThroughHandler._assemble_headers(
                api_key=api_key, request=request
            ),
        )  # dynamically construct pass-through endpoint based on incoming path
        received_value = await endpoint_func(
            request,
            fastapi_response,
            user_api_key_dict,
            stream=is_streaming_request,  # type: ignore
            query_params=dict(request.query_params),  # type: ignore
        )

        return received_value

    @staticmethod
    def _append_openai_beta_header(headers: dict, request: Request) -> dict:
        """
        Appends the OpenAI-Beta header to the headers if the request is an OpenAI Assistants API request
        """
        if (
            RouteChecks._is_assistants_api_request(request) is True
            and "OpenAI-Beta" not in headers
        ):
            headers["OpenAI-Beta"] = "assistants=v2"
        return headers

    @staticmethod
    def _assemble_headers(api_key: str, request: Request) -> dict:
        base_headers = {
            "authorization": "Bearer {}".format(api_key),
            "api-key": "{}".format(api_key),
        }
        return BaseOpenAIPassThroughHandler._append_openai_beta_header(
            headers=base_headers,
            request=request,
        )

    @staticmethod
    def _join_url_paths(
        base_url: httpx.URL, path: str, custom_llm_provider: litellm.LlmProviders
    ) -> str:
        """
        Properly joins a base URL with a path, preserving any existing path in the base URL.
        """
        # Join paths correctly by removing trailing/leading slashes as needed
        if not base_url.path or base_url.path == "/":
            # If base URL has no path, just use the new path
            joined_path_str = str(base_url.copy_with(path=path))
        else:
            # Otherwise, combine the paths
            base_path = base_url.path.rstrip("/")
            clean_path = path.lstrip("/")
            full_path = f"{base_path}/{clean_path}"
            joined_path_str = str(base_url.copy_with(path=full_path))

        # Apply OpenAI-specific path handling for both branches
        if (
            custom_llm_provider == litellm.LlmProviders.OPENAI
            and "/v1/" not in joined_path_str
        ):
            # Insert v1 after api.openai.com for OpenAI requests
            joined_path_str = joined_path_str.replace(
                "api.openai.com/", "api.openai.com/v1/"
            )

        return joined_path_str
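A worked example of the path-joining behavior above: for plain OpenAI, `/v1/` is spliced in after the host when the caller omitted it. This assumes the module is importable at the path the other files in this commit use:

import httpx
import litellm
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
    BaseOpenAIPassThroughHandler,
)

url = BaseOpenAIPassThroughHandler._join_url_paths(
    base_url=httpx.URL("https://api.openai.com/"),
    path="/chat/completions",
    custom_llm_provider=litellm.LlmProviders.OPENAI,
)
print(url)  # https://api.openai.com/v1/chat/completions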
@@ -0,0 +1,221 @@
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.anthropic.chat.handler import (
    ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
    PassthroughStandardLoggingPayload,
)
from litellm.types.utils import ModelResponse, TextCompletionResponse

if TYPE_CHECKING:
    from ..success_handler import PassThroughEndpointLogging
    from ..types import EndpointType
else:
    PassThroughEndpointLogging = Any
    EndpointType = Any


class AnthropicPassthroughLoggingHandler:
    @staticmethod
    def anthropic_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
        """
        model = response_body.get("model", "")
        litellm_model_response: ModelResponse = AnthropicConfig().transform_response(
            raw_response=httpx_response,
            model_response=litellm.ModelResponse(),
            model=model,
            messages=[],
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            request_data={},
            encoding=litellm.encoding,
            json_mode=False,
            litellm_params={},
        )

        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=litellm_model_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
        )

        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _get_user_from_metadata(
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
    ) -> Optional[str]:
        request_body = passthrough_logging_payload.get("request_body")
        if request_body:
            return get_end_user_id_from_request_body(request_body)
        return None

    @staticmethod
    def _create_anthropic_response_logging_payload(
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
    ):
        """
        Create the standard logging object for Anthropic passthrough

        handles streaming and non-streaming responses
        """
        try:
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
            )
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
                kwargs.get("passthrough_logging_payload")
            )
            if passthrough_logging_payload:
                user = AnthropicPassthroughLoggingHandler._get_user_from_metadata(
                    passthrough_logging_payload=passthrough_logging_payload,
                )
                if user:
                    kwargs.setdefault("litellm_params", {})
                    kwargs["litellm_params"].update(
                        {"proxy_server_request": {"body": {"user": user}}}
                    )

            # pretty print standard logging object
            verbose_proxy_logger.debug(
                "kwargs= %s",
                json.dumps(kwargs, indent=4, default=str),
            )

            # set litellm_call_id to logging response object
            litellm_model_response.id = logging_obj.litellm_call_id
            litellm_model_response.model = model
            logging_obj.model_call_details["model"] = model
            logging_obj.model_call_details["custom_llm_provider"] = (
                litellm.LlmProviders.ANTHROPIC.value
            )
            return kwargs
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error creating Anthropic response logging payload: %s", e
            )
            return kwargs

    @staticmethod
    def _handle_logging_anthropic_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks

        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks
        """

        model = request_body.get("model", "")
        complete_streaming_response = (
            AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
            )
        )
        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": {},
            }
        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs={},
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )

        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _build_complete_streaming_response(
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Builds complete response from raw Anthropic chunks

        - Converts str chunks to generic chunks
        - Converts generic chunks to litellm chunks (OpenAI format)
        - Builds complete response from litellm chunks
        """
        anthropic_model_response_iterator = AnthropicModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        all_openai_chunks = []
        for _chunk_str in all_chunks:
            try:
                transformed_openai_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
                    chunk=_chunk_str
                )
                if transformed_openai_chunk is not None:
                    all_openai_chunks.append(transformed_openai_chunk)

                verbose_proxy_logger.debug(
                    "all openai chunks= %s",
                    json.dumps(all_openai_chunks, indent=4, default=str),
                )
            except (StopIteration, StopAsyncIteration):
                break
        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks
        )
        return complete_streaming_response
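The final step above leans on litellm's generic chunk re-assembly helper. A minimal sketch of that helper in isolation, assuming a model and key already configured for litellm (the model name is hypothetical):

import litellm

# Collect OpenAI-format chunks from any litellm streaming call ...
chunks = []
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-20240620",  # hypothetical model name
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
for chunk in response:
    chunks.append(chunk)

# ... then rebuild a single ModelResponse, the same way the handler does.
rebuilt = litellm.stream_chunk_builder(chunks=chunks)
print(rebuilt.choices[0].message.content)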
@@ -0,0 +1,332 @@
import asyncio
import json
import time
from datetime import datetime
from typing import Literal, Optional, TypedDict
from urllib.parse import urlparse

import httpx

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
    get_standard_logging_object_payload,
)
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.types.passthrough_endpoints.assembly_ai import (
    ASSEMBLY_AI_MAX_POLLING_ATTEMPTS,
    ASSEMBLY_AI_POLLING_INTERVAL,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
    PassthroughStandardLoggingPayload,
)


class AssemblyAITranscriptResponse(TypedDict, total=False):
    id: str
    speech_model: str
    acoustic_model: str
    language_code: str
    status: str
    audio_duration: float


class AssemblyAIPassthroughLoggingHandler:
    def __init__(self):
        self.assembly_ai_base_url = "https://api.assemblyai.com"
        self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
        """
        The base URLs for the AssemblyAI API (default and EU regions)
        """

        self.polling_interval: float = ASSEMBLY_AI_POLLING_INTERVAL
        """
        The polling interval for the AssemblyAI API.
        litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript.
        """

        self.max_polling_attempts = ASSEMBLY_AI_MAX_POLLING_ATTEMPTS
        """
        The maximum number of polling attempts for the AssemblyAI API.
        """

    def assemblyai_passthrough_logging_handler(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """
        Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit.
        """
        executor.submit(
            self._handle_assemblyai_passthrough_logging,
            httpx_response,
            response_body,
            logging_obj,
            url_route,
            result,
            start_time,
            end_time,
            cache_hit,
            **kwargs,
        )

    def _handle_assemblyai_passthrough_logging(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """
        Handles logging for successful AssemblyAI passthrough requests
        """
        from ..pass_through_endpoints import pass_through_endpoint_logging

        model = response_body.get("speech_model", "")
        verbose_proxy_logger.debug(
            "response body %s", json.dumps(response_body, indent=4)
        )
        kwargs["model"] = model
        kwargs["custom_llm_provider"] = "assemblyai"
        response_cost: Optional[float] = None

        transcript_id = response_body.get("id")
        if transcript_id is None:
            raise ValueError(
                "Transcript ID is required to log the cost of the transcription"
            )
        transcript_response = self._poll_assembly_for_transcript_response(
            transcript_id=transcript_id, url_route=url_route
        )
        verbose_proxy_logger.debug(
            "finished polling assembly for transcript response- got transcript response %s",
            json.dumps(transcript_response, indent=4),
        )
        if transcript_response:
            cost = self.get_cost_for_assembly_transcript(
                speech_model=model,
                transcript_response=transcript_response,
            )
            response_cost = cost

        # Make standard logging object for AssemblyAI
        standard_logging_object = get_standard_logging_object_payload(
            kwargs=kwargs,
            init_response_obj=transcript_response,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
            status="success",
        )

        passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
            kwargs.get("passthrough_logging_payload")
        )

        verbose_proxy_logger.debug(
            "standard_passthrough_logging_object %s",
            json.dumps(passthrough_logging_payload, indent=4),
        )

        # pretty print standard logging object
        verbose_proxy_logger.debug(
            "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
        )
        logging_obj.model_call_details["model"] = model
        logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
        logging_obj.model_call_details["response_cost"] = response_cost

        asyncio.run(
            pass_through_endpoint_logging._handle_logging(
                logging_obj=logging_obj,
                standard_logging_response_object=self._get_response_to_log(
                    transcript_response
                ),
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
        )

    def _get_response_to_log(
        self, transcript_response: Optional[AssemblyAITranscriptResponse]
    ) -> dict:
        if transcript_response is None:
            return {}
        return dict(transcript_response)

    def _get_assembly_transcript(
        self,
        transcript_id: str,
        request_region: Optional[Literal["eu"]] = None,
    ) -> Optional[dict]:
        """
        Get the transcript details from AssemblyAI API

        Args:
            transcript_id (str): ID of the transcript to fetch
            request_region (Optional[Literal["eu"]]): Region the request was made in

        Returns:
            Optional[dict]: Transcript details if successful, None otherwise
        """
        from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
            passthrough_endpoint_router,
        )

        _base_url = (
            self.assembly_ai_eu_base_url
            if request_region == "eu"
            else self.assembly_ai_base_url
        )
        _api_key = passthrough_endpoint_router.get_credentials(
            custom_llm_provider="assemblyai",
            region_name=request_region,
        )
        if _api_key is None:
            raise ValueError("AssemblyAI API key not found")
        try:
            url = f"{_base_url}/v2/transcript/{transcript_id}"
            headers = {
                "Authorization": f"Bearer {_api_key}",
                "Content-Type": "application/json",
            }

            response = httpx.get(url, headers=headers)
            response.raise_for_status()

            return response.json()
        except Exception as e:
            verbose_proxy_logger.exception(
                f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
            )
            return None

    def _poll_assembly_for_transcript_response(
        self,
        transcript_id: str,
        url_route: Optional[str] = None,
    ) -> Optional[AssemblyAITranscriptResponse]:
        """
        Poll the status of the transcript until it is completed or timeout (30 minutes)
        """
        for _ in range(
            self.max_polling_attempts
        ):  # 180 attempts * 10s = 30 minutes max
            transcript = self._get_assembly_transcript(
                request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
                    url=url_route
                ),
                transcript_id=transcript_id,
            )
            if transcript is None:
                return None
            if (
                transcript.get("status") == "completed"
                or transcript.get("status") == "error"
            ):
                return AssemblyAITranscriptResponse(**transcript)
            time.sleep(self.polling_interval)
        return None

    @staticmethod
    def get_cost_for_assembly_transcript(
        transcript_response: AssemblyAITranscriptResponse,
        speech_model: str,
    ) -> Optional[float]:
        """
        Get the cost for the assembly transcript
        """
        _audio_duration = transcript_response.get("audio_duration")
        if _audio_duration is None:
            return None
        _cost_per_second = (
            AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
                speech_model=speech_model
            )
        )
        if _cost_per_second is None:
            return None
        return _audio_duration * _cost_per_second
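A worked example of the cost arithmetic above, with hypothetical numbers: a 600-second transcript at an input cost of $0.0001 per second.

# Hypothetical numbers, for illustration only.
audio_duration = 600.0          # seconds, from the transcript response
input_cost_per_second = 0.0001  # USD, from litellm's model info

response_cost = audio_duration * input_cost_per_second
print(response_cost)  # 0.06 (USD)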
    @staticmethod
    def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
        """
        Get the cost per second for the assembly model.
        Falls back to assemblyai/nano if the specific speech model info cannot be found.
        """
        try:
            # First try with the provided speech model
            try:
                model_info = litellm.get_model_info(
                    model=speech_model,
                    custom_llm_provider="assemblyai",
                )
                if model_info and model_info.get("input_cost_per_second") is not None:
                    return model_info.get("input_cost_per_second")
            except Exception:
                pass  # Continue to fallback if model not found

            # Fallback to assemblyai/nano if speech model info not found
            try:
                model_info = litellm.get_model_info(
                    model="assemblyai/nano",
                    custom_llm_provider="assemblyai",
                )
                if model_info and model_info.get("input_cost_per_second") is not None:
                    return model_info.get("input_cost_per_second")
            except Exception:
                pass

            return None
        except Exception as e:
            verbose_proxy_logger.exception(
                f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
            )
            return None

    @staticmethod
    def _should_log_request(request_method: str) -> bool:
        """
        Only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
        """
        return request_method == "POST"

    @staticmethod
    def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
        """
        Get the region from the URL
        """
        if url is None:
            return None
        if urlparse(url).hostname == "eu.assemblyai.com":
            return "eu"
        return None

    @staticmethod
    def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
        """
        Get the base URL for the AssemblyAI API
        if region == "eu", return "https://api.eu.assemblyai.com"
        else return "https://api.assemblyai.com"
        """
        if region == "eu":
            return "https://api.eu.assemblyai.com"
        return "https://api.assemblyai.com"
@@ -0,0 +1,221 @@
import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union

import httpx

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
    get_standard_logging_object_payload,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
    PassthroughStandardLoggingPayload,
)
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse

if TYPE_CHECKING:
    from ..success_handler import PassThroughEndpointLogging
    from ..types import EndpointType
else:
    PassThroughEndpointLogging = Any
    EndpointType = Any


class BasePassthroughLoggingHandler(ABC):
    @property
    @abstractmethod
    def llm_provider_name(self) -> LlmProviders:
        pass

    @abstractmethod
    def get_provider_config(self, model: str) -> BaseConfig:
        pass

    def passthrough_chat_handler(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transforms the LLM response into an OpenAI-format response and generates a standard logging object so downstream logging can be handled.
        """
        model = request_body.get("model", response_body.get("model", ""))
        provider_config = self.get_provider_config(model=model)
        litellm_model_response: ModelResponse = provider_config.transform_response(
            raw_response=httpx_response,
            model_response=litellm.ModelResponse(),
            model=model,
            messages=[],
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            request_data={},
            encoding=litellm.encoding,
            json_mode=False,
            litellm_params={},
        )

        kwargs = self._create_response_logging_payload(
            litellm_model_response=litellm_model_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
        )

        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }

    def _get_user_from_metadata(
        self,
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
    ) -> Optional[str]:
        request_body = passthrough_logging_payload.get("request_body")
        if request_body:
            return get_end_user_id_from_request_body(request_body)
        return None

    def _create_response_logging_payload(
        self,
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
    ) -> dict:
        """
        Create the standard logging object for a generic LLM passthrough.

        Handles streaming and non-streaming responses.
        """
        try:
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
            )

            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
                kwargs.get("passthrough_logging_payload")
            )
            if passthrough_logging_payload:
                user = self._get_user_from_metadata(
                    passthrough_logging_payload=passthrough_logging_payload,
                )
                if user:
                    kwargs.setdefault("litellm_params", {})
                    kwargs["litellm_params"].update(
                        {"proxy_server_request": {"body": {"user": user}}}
                    )

            # Make the standard logging object for the provider
            standard_logging_object = get_standard_logging_object_payload(
                kwargs=kwargs,
                init_response_obj=litellm_model_response,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                status="success",
            )

            # pretty print standard logging object
            verbose_proxy_logger.debug(
                "standard_logging_object= %s",
                json.dumps(standard_logging_object, indent=4),
            )
            kwargs["standard_logging_object"] = standard_logging_object

            # set litellm_call_id on the logging response object
            litellm_model_response.id = logging_obj.litellm_call_id
            litellm_model_response.model = model
            logging_obj.model_call_details["model"] = model
            return kwargs
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error creating LLM passthrough response logging payload: %s", e
            )
            return kwargs

    @abstractmethod
    def _build_complete_streaming_response(
        self,
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Builds the complete response from raw chunks:

        - Converts str chunks to generic chunks
        - Converts generic chunks to litellm chunks (OpenAI format)
        - Builds the complete response from litellm chunks
        """
        pass

    def _handle_logging_llm_collected_chunks(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from an LLM passthrough endpoint and logs them in litellm callbacks:

        - Builds the complete response from chunks
        - Creates the standard logging object
        - Logs in litellm callbacks
        """
        model = request_body.get("model", "")
        complete_streaming_response = self._build_complete_streaming_response(
            all_chunks=all_chunks,
            litellm_logging_obj=litellm_logging_obj,
            model=model,
        )
        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for LLM passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": {},
            }
        kwargs = self._create_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs={},
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )

        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }
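To see how the abstract pieces fit together, here is a minimal sketch of a concrete subclass. It is illustrative only: it assumes litellm.AnthropicConfig() and LlmProviders.ANTHROPIC are available (the real Anthropic handler in this commit is implemented differently), and it stubs out streaming aggregation rather than parsing real chunks.

from typing import List, Optional, Union

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse

from .base_passthrough_logging_handler import BasePassthroughLoggingHandler


class ExamplePassthroughLoggingHandler(BasePassthroughLoggingHandler):
    @property
    def llm_provider_name(self) -> LlmProviders:
        # Which provider this handler logs for
        return LlmProviders.ANTHROPIC

    def get_provider_config(self, model: str) -> BaseConfig:
        # Config that knows how to transform this provider's raw response
        return litellm.AnthropicConfig()

    def _build_complete_streaming_response(
        self,
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        # A real handler converts raw SSE chunk strings into OpenAI-format
        # chunks and aggregates them (see the Cohere handler below);
        # returning None means "could not build a response, skip logging".
        return None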
@@ -0,0 +1,56 @@
from typing import List, Optional, Union

from litellm import stream_chunk_builder
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.cohere.chat.v2_transformation import CohereV2ChatConfig
from litellm.llms.cohere.common_utils import (
    ModelResponseIterator as CohereModelResponseIterator,
)
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse

from .base_passthrough_logging_handler import BasePassthroughLoggingHandler


class CoherePassthroughLoggingHandler(BasePassthroughLoggingHandler):
    @property
    def llm_provider_name(self) -> LlmProviders:
        return LlmProviders.COHERE

    def get_provider_config(self, model: str) -> BaseConfig:
        return CohereV2ChatConfig()

    def _build_complete_streaming_response(
        self,
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        cohere_model_response_iterator = CohereModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        litellm_custom_stream_wrapper = CustomStreamWrapper(
            completion_stream=cohere_model_response_iterator,
            model=model,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="cohere",
        )
        all_openai_chunks = []
        for _chunk_str in all_chunks:
            try:
                generic_chunk = (
                    cohere_model_response_iterator.convert_str_chunk_to_generic_chunk(
                        chunk=_chunk_str
                    )
                )
                litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
                    chunk=generic_chunk
                )
                if litellm_chunk is not None:
                    all_openai_chunks.append(litellm_chunk)
            except (StopIteration, StopAsyncIteration):
                break
        complete_streaming_response = stream_chunk_builder(chunks=all_openai_chunks)
        return complete_streaming_response
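The first stage of that pipeline, turning raw chunk strings into structured data before they become OpenAI-format chunks, is conceptually just SSE-line parsing. A self-contained sketch of that idea (illustrative only; litellm's real converters also handle provider-specific event types and are not limited to this):

import json

def parse_sse_lines(lines):
    # Keep only "data: {...}" payload lines, drop the terminal [DONE] marker
    parsed = []
    for line in lines:
        if line.startswith("data:"):
            payload = line[len("data:"):].strip()
            if payload and payload != "[DONE]":
                parsed.append(json.loads(payload))
    return parsed

print(parse_sse_lines(['data: {"text": "Hel"}', 'data: {"text": "lo"}', "data: [DONE]"]))
# -> [{'text': 'Hel'}, {'text': 'lo'}]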
@@ -0,0 +1,261 @@
import json
import re
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import httpx

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    ModelResponseIterator as VertexModelResponseIterator,
)
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import (
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    TextCompletionResponse,
)

if TYPE_CHECKING:
    from ..success_handler import PassThroughEndpointLogging
    from ..types import EndpointType
else:
    PassThroughEndpointLogging = Any
    EndpointType = Any


class VertexPassthroughLoggingHandler:
    @staticmethod
    def vertex_passthrough_handler(
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        if "generateContent" in url_route:
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)

            instance_of_vertex_llm = litellm.VertexGeminiConfig()
            litellm_model_response: ModelResponse = (
                instance_of_vertex_llm.transform_response(
                    model=model,
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
                    raw_response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
                    request_data={},
                    encoding=litellm.encoding,
                )
            )
            kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
                litellm_model_response=litellm_model_response,
                model=model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
                    url_route
                ),
            )

            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }

        elif "predict" in url_route:
            from litellm.llms.vertex_ai.image_generation.image_generation_handler import (
                VertexImageGeneration,
            )
            from litellm.types.utils import PassthroughCallTypes

            vertex_image_generation_class = VertexImageGeneration()

            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
            _json_response = httpx_response.json()

            litellm_prediction_response: Union[
                ModelResponse, EmbeddingResponse, ImageResponse
            ] = ModelResponse()
            if vertex_image_generation_class.is_image_generation_response(
                _json_response
            ):
                litellm_prediction_response = (
                    vertex_image_generation_class.process_image_generation_response(
                        _json_response,
                        model_response=litellm.ImageResponse(),
                        model=model,
                    )
                )

                logging_obj.call_type = (
                    PassthroughCallTypes.passthrough_image_generation.value
                )
            else:
                litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                    response=_json_response,
                    model=model,
                    model_response=litellm.EmbeddingResponse(),
                )
            if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
                litellm_prediction_response.model = model

            logging_obj.model = model
            logging_obj.model_call_details["model"] = logging_obj.model

            return {
                "result": litellm_prediction_response,
                "kwargs": kwargs,
            }
        else:
            return {
                "result": None,
                "kwargs": kwargs,
            }

    @staticmethod
    def _handle_logging_vertex_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from the Vertex passthrough endpoint and logs them in litellm callbacks:

        - Builds the complete response from chunks
        - Creates the standard logging object
        - Logs in litellm callbacks
        """
        kwargs: Dict[str, Any] = {}
        model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
        complete_streaming_response = (
            VertexPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
            )
        )

        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": kwargs,
            }

        kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
            custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
                url_route
            ),
        )

        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _build_complete_streaming_response(
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        vertex_iterator = VertexModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
            completion_stream=vertex_iterator,
            model=model,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="vertex_ai",
        )
        all_openai_chunks = []
        for chunk in all_chunks:
            generic_chunk = vertex_iterator._common_chunk_parsing_logic(chunk)
            litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
                chunk=generic_chunk
            )
            if litellm_chunk is not None:
                all_openai_chunks.append(litellm_chunk)

        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks
        )

        return complete_streaming_response

    @staticmethod
    def extract_model_from_url(url: str) -> str:
        pattern = r"/models/([^:]+)"
        match = re.search(pattern, url)
        if match:
            return match.group(1)
        return "unknown"

    @staticmethod
    def _get_custom_llm_provider_from_url(url: str) -> str:
        parsed_url = urlparse(url)
        if parsed_url.hostname and parsed_url.hostname.endswith(
            "generativelanguage.googleapis.com"
        ):
            return litellm.LlmProviders.GEMINI.value
        return litellm.LlmProviders.VERTEX_AI.value

    @staticmethod
    def _create_vertex_response_logging_payload_for_generate_content(
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: str,
    ):
        """
        Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming)
        """
        response_cost = litellm.completion_cost(
            completion_response=litellm_model_response,
            model=model,
        )
        kwargs["response_cost"] = response_cost
        kwargs["model"] = model

        # pretty print standard logging object
        verbose_proxy_logger.debug("kwargs= %s", json.dumps(kwargs, indent=4))

        # set litellm_call_id to logging response object
        litellm_model_response.id = logging_obj.litellm_call_id
        logging_obj.model = litellm_model_response.model or model
        logging_obj.model_call_details["model"] = logging_obj.model
        logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
        return kwargs
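The URL helpers at the bottom of this handler are self-contained string logic, so their behavior is easy to verify. A quick standalone check (the project and model names in the sample URLs are illustrative):

import re
from urllib.parse import urlparse

# Same pattern as extract_model_from_url: the model name sits between
# "/models/" and the ":<rpc>" suffix of a generateContent URL.
url = (
    "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project"
    "/locations/us-central1/publishers/google/models/gemini-1.5-pro:generateContent"
)
match = re.search(r"/models/([^:]+)", url)
print(match.group(1))  # -> gemini-1.5-pro

# Same hostname check as _get_custom_llm_provider_from_url: the Google AI
# Studio (Gemini API) host maps to "gemini", everything else to "vertex_ai".
gemini_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
host = urlparse(gemini_url).hostname
print(host.endswith("generativelanguage.googleapis.com"))  # -> True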
File diff suppressed because it is too large
@@ -0,0 +1,193 @@
from typing import Dict, Optional

from litellm._logging import verbose_router_logger
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials


class PassthroughEndpointRouter:
    """
    Use this class to set/get credentials for pass-through endpoints
    """

    def __init__(self):
        self.credentials: Dict[str, str] = {}
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}
        self.default_vertex_config: Optional[VertexPassThroughCredentials] = None

    def set_pass_through_credentials(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
        api_key: Optional[str],
    ):
        """
        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.

        Args:
            custom_llm_provider: The provider of the pass-through endpoint
            api_base: The base URL of the pass-through endpoint
            api_key: The API key for the pass-through endpoint
        """
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=self._get_region_name_from_api_base(
                api_base=api_base, custom_llm_provider=custom_llm_provider
            ),
        )
        if api_key is None:
            raise ValueError("api_key is required for setting pass-through credentials")
        self.credentials[credential_name] = api_key

    def get_credentials(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> Optional[str]:
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=region_name,
        )
        verbose_router_logger.debug(
            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
        )
        if credential_name in self.credentials:
            verbose_router_logger.debug(f"Found credentials for {credential_name}")
            return self.credentials[credential_name]
        else:
            verbose_router_logger.debug(
                f"No credentials found for {credential_name}, looking for env variable"
            )
            _env_variable_name = (
                self._get_default_env_variable_name_passthrough_endpoint(
                    custom_llm_provider=custom_llm_provider,
                )
            )
            return get_secret_str(_env_variable_name)

    def _get_vertex_env_vars(self) -> VertexPassThroughCredentials:
        """
        Helper to get vertex pass-through config from environment variables.

        The following environment variables are used:
        - DEFAULT_VERTEXAI_PROJECT (project id)
        - DEFAULT_VERTEXAI_LOCATION (location)
        - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
        """
        return VertexPassThroughCredentials(
            vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
            vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
            vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
        )

    def set_default_vertex_config(self, config: Optional[dict] = None):
        """Sets vertex configuration from the provided config and/or environment variables

        Args:
            config (Optional[dict]): Configuration dictionary
            Example: {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "os.environ/GOOGLE_CREDS"
            }
        """
        # If no config is provided, fall back to environment variables
        if config is None:
            self.default_vertex_config = self._get_vertex_env_vars()
            return

        if isinstance(config, dict):
            for key, value in config.items():
                if isinstance(value, str) and value.startswith("os.environ/"):
                    config[key] = get_secret_str(value)

        self.default_vertex_config = VertexPassThroughCredentials(**config)

    def add_vertex_credentials(
        self,
        project_id: str,
        location: str,
        vertex_credentials: VERTEX_CREDENTIALS_TYPES,
    ):
        """
        Add the vertex credentials for the given project-id, location
        """
        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            verbose_router_logger.debug(
                "No deployment key found for project-id, location"
            )
            return
        vertex_pass_through_credentials = VertexPassThroughCredentials(
            vertex_project=project_id,
            vertex_location=location,
            vertex_credentials=vertex_credentials,
        )
        self.deployment_key_to_vertex_credentials[
            deployment_key
        ] = vertex_pass_through_credentials

    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        """
        Get the deployment key for the given project-id, location
        """
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"

    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[VertexPassThroughCredentials]:
        """
        Get the vertex credentials for the given project-id, location
        """
        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )

        if deployment_key is None:
            return self.default_vertex_config
        if deployment_key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[deployment_key]
        else:
            return self.default_vertex_config

    def _get_credential_name_for_provider(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> str:
        if region_name is None:
            return f"{custom_llm_provider.upper()}_API_KEY"
        return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"

    def _get_region_name_from_api_base(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
    ) -> Optional[str]:
        """
        Get the region name from the API base.

        Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
        """
        if custom_llm_provider == "assemblyai":
            if api_base and "eu" in api_base:
                return "eu"
        return None

    @staticmethod
    def _get_default_env_variable_name_passthrough_endpoint(
        custom_llm_provider: str,
    ) -> str:
        return f"{custom_llm_provider.upper()}_API_KEY"
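The credential lookup key is a simple string convention, which is worth seeing in isolation. A standalone sketch of the same naming logic (the provider and region values are illustrative):

from typing import Optional

# Mirrors _get_credential_name_for_provider: "<PROVIDER>_API_KEY", with the
# region spliced in when the endpoint is region-scoped.
def credential_name(custom_llm_provider: str, region_name: Optional[str]) -> str:
    if region_name is None:
        return f"{custom_llm_provider.upper()}_API_KEY"
    return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"

print(credential_name("assemblyai", None))  # -> ASSEMBLYAI_API_KEY
print(credential_name("assemblyai", "eu"))  # -> ASSEMBLYAI_EU_API_KEY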
@@ -0,0 +1,159 @@
import asyncio
import threading
from datetime import datetime
from typing import List, Optional

import httpx

from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from litellm.types.utils import StandardPassThroughResponseObject

from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
    AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
    VertexPassthroughLoggingHandler,
)
from .success_handler import PassThroughEndpointLogging


class PassThroughStreamingHandler:
    @staticmethod
    async def chunk_processor(
        response: httpx.Response,
        request_body: Optional[dict],
        litellm_logging_obj: LiteLLMLoggingObj,
        endpoint_type: EndpointType,
        start_time: datetime,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
    ):
        """
        - Yields chunks from the response
        - Collects non-empty chunks for post-processing (logging)
        """
        try:
            raw_bytes: List[bytes] = []
            async for chunk in response.aiter_bytes():
                raw_bytes.append(chunk)
                yield chunk

            # After all chunks are processed, handle post-processing
            end_time = datetime.now()

            asyncio.create_task(
                PassThroughStreamingHandler._route_streaming_logging_to_handler(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body or {},
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    raw_bytes=raw_bytes,
                    end_time=end_time,
                )
            )
        except Exception as e:
            verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
            raise

    @staticmethod
    async def _route_streaming_logging_to_handler(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        raw_bytes: List[bytes],
        end_time: datetime,
    ):
        """
        Route the logging for the collected chunks to the appropriate handler

        Supported endpoint types:
        - Anthropic
        - Vertex AI
        """
        all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
            raw_bytes
        )
        standard_logging_response_object: Optional[
            PassThroughEndpointLoggingResultValues
        ] = None
        kwargs: dict = {}
        if endpoint_type == EndpointType.ANTHROPIC:
            anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
                litellm_logging_obj=litellm_logging_obj,
                passthrough_success_handler_obj=passthrough_success_handler_obj,
                url_route=url_route,
                request_body=request_body,
                endpoint_type=endpoint_type,
                start_time=start_time,
                all_chunks=all_chunks,
                end_time=end_time,
            )
            standard_logging_response_object = (
                anthropic_passthrough_logging_handler_result["result"]
            )
            kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
        elif endpoint_type == EndpointType.VERTEX_AI:
            vertex_passthrough_logging_handler_result = (
                VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    all_chunks=all_chunks,
                    end_time=end_time,
                )
            )
            standard_logging_response_object = (
                vertex_passthrough_logging_handler_result["result"]
            )
            kwargs = vertex_passthrough_logging_handler_result["kwargs"]

        if standard_logging_response_object is None:
            standard_logging_response_object = StandardPassThroughResponseObject(
                response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
            )
        threading.Thread(
            target=litellm_logging_obj.success_handler,
            args=(
                standard_logging_response_object,
                start_time,
                end_time,
                False,
            ),
        ).start()
        await litellm_logging_obj.async_success_handler(
            result=standard_logging_response_object,
            start_time=start_time,
            end_time=end_time,
            cache_hit=False,
            **kwargs,
        )

    @staticmethod
    def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
        """
        Converts a list of raw bytes into a list of string lines, similar to aiter_lines()

        Args:
            raw_bytes: List of bytes chunks from aiter_bytes()

        Returns:
            List of string lines, with each line being a complete `data: {...}` chunk
        """
        # Combine all bytes and decode to string
        combined_str = b"".join(raw_bytes).decode("utf-8")

        # Split by newlines and filter out empty lines
        lines = [line.strip() for line in combined_str.split("\n") if line.strip()]

        return lines
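The byte-to-line conversion is pure string handling, so it can be checked standalone. A sketch with illustrative chunks:

# Mirrors _convert_raw_bytes_to_str_lines: join the raw byte chunks, decode,
# then keep the non-empty stripped lines. Chunk boundaries from aiter_bytes()
# need not align with line boundaries, which is why the bytes are joined
# before splitting.
raw_bytes = [b'data: {"a"', b': 1}\n\ndata: {"b": 2}\n']
combined_str = b"".join(raw_bytes).decode("utf-8")
lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
print(lines)  # -> ['data: {"a": 1}', 'data: {"b": 2}']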
@@ -0,0 +1,221 @@
import json
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse

import httpx

from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
    PassthroughStandardLoggingPayload,
)
from litellm.types.utils import StandardPassThroughResponseObject
from litellm.utils import executor as thread_pool_executor

from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
    AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
    AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.cohere_passthrough_logging_handler import (
    CoherePassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
    VertexPassthroughLoggingHandler,
)

cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler()


class PassThroughEndpointLogging:
    def __init__(self):
        self.TRACKED_VERTEX_ROUTES = [
            "generateContent",
            "streamGenerateContent",
            "predict",
        ]

        # Anthropic
        self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]

        # Cohere
        self.TRACKED_COHERE_ROUTES = ["/v2/chat"]
        self.assemblyai_passthrough_logging_handler = (
            AssemblyAIPassthroughLoggingHandler()
        )

    async def _handle_logging(
        self,
        logging_obj: LiteLLMLoggingObj,
        standard_logging_response_object: Union[
            StandardPassThroughResponseObject,
            PassThroughEndpointLoggingResultValues,
            dict,
        ],
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """Helper function to handle both sync and async logging operations"""
        # Submit to thread pool for sync logging
        thread_pool_executor.submit(
            logging_obj.success_handler,
            standard_logging_response_object,
            start_time,
            end_time,
            cache_hit,
            **kwargs,
        )

        # Handle async logging
        await logging_obj.async_success_handler(
            result=(
                json.dumps(result)
                if isinstance(result, dict)
                else standard_logging_response_object
            ),
            start_time=start_time,
            end_time=end_time,
            cache_hit=False,
            **kwargs,
        )

    async def pass_through_async_success_handler(
        self,
        httpx_response: httpx.Response,
        response_body: Optional[dict],
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
        **kwargs,
    ):
        standard_logging_response_object: Optional[
            PassThroughEndpointLoggingResultValues
        ] = None
        logging_obj.model_call_details["passthrough_logging_payload"] = (
            passthrough_logging_payload
        )
        if self.is_vertex_route(url_route):
            vertex_passthrough_logging_handler_result = (
                VertexPassthroughLoggingHandler.vertex_passthrough_handler(
                    httpx_response=httpx_response,
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                vertex_passthrough_logging_handler_result["result"]
            )
            kwargs = vertex_passthrough_logging_handler_result["kwargs"]
        elif self.is_anthropic_route(url_route):
            anthropic_passthrough_logging_handler_result = (
                AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    **kwargs,
                )
            )

            standard_logging_response_object = (
                anthropic_passthrough_logging_handler_result["result"]
            )
            kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
        elif self.is_cohere_route(url_route):
            cohere_passthrough_logging_handler_result = (
                cohere_passthrough_logging_handler.passthrough_chat_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                cohere_passthrough_logging_handler_result["result"]
            )
            kwargs = cohere_passthrough_logging_handler_result["kwargs"]
        elif self.is_assemblyai_route(url_route):
            if (
                AssemblyAIPassthroughLoggingHandler._should_log_request(
                    httpx_response.request.method
                )
                is not True
            ):
                return
            self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
            return

        if standard_logging_response_object is None:
            standard_logging_response_object = StandardPassThroughResponseObject(
                response=httpx_response.text
            )

        await self._handle_logging(
            logging_obj=logging_obj,
            standard_logging_response_object=standard_logging_response_object,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            **kwargs,
        )

    def is_vertex_route(self, url_route: str):
        for route in self.TRACKED_VERTEX_ROUTES:
            if route in url_route:
                return True
        return False

    def is_anthropic_route(self, url_route: str):
        for route in self.TRACKED_ANTHROPIC_ROUTES:
            if route in url_route:
                return True
        return False

    def is_cohere_route(self, url_route: str):
        for route in self.TRACKED_COHERE_ROUTES:
            if route in url_route:
                return True
        return False

    def is_assemblyai_route(self, url_route: str):
        parsed_url = urlparse(url_route)
        if parsed_url.hostname == "api.assemblyai.com":
            return True
        elif "/transcript" in parsed_url.path:
            return True
        return False
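Route detection above is plain substring matching over the tracked route lists. A standalone sketch of the same logic (the sample URLs are illustrative):

# Mirrors is_vertex_route / is_anthropic_route / is_cohere_route: a URL is
# attributed to a provider if it contains one of the tracked route fragments.
TRACKED_VERTEX_ROUTES = ["generateContent", "streamGenerateContent", "predict"]
TRACKED_ANTHROPIC_ROUTES = ["/messages"]

def matches(tracked, url_route):
    return any(route in url_route for route in tracked)

vertex_url = "/v1/projects/p/locations/l/publishers/google/models/gemini-1.5-pro:generateContent"
print(matches(TRACKED_VERTEX_ROUTES, vertex_url))        # -> True
print(matches(TRACKED_ANTHROPIC_ROUTES, "/anthropic/v1/messages"))  # -> True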