structure saas with tools

This commit is contained in:
Davidson Gomes
2025-04-25 15:30:54 -03:00
commit 1aef473937
16434 changed files with 6584257 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
from fastapi import Request
def get_litellm_virtual_key(request: Request) -> str:
"""
Extract and format API key from request headers.
Prioritizes x-litellm-api-key over Authorization header.
Vertex JS SDK uses `Authorization` header, we use `x-litellm-api-key` to pass litellm virtual key
"""
litellm_api_key = request.headers.get("x-litellm-api-key")
if litellm_api_key:
return f"Bearer {litellm_api_key}"
return request.headers.get("Authorization", "")
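A minimal sketch of how this helper resolves the two headers, built with a bare ASGI scope (key values are illustrative):
from fastapi import Request

def _make_request(headers: dict) -> Request:
    # ASGI scopes carry headers as a list of (bytes, bytes) pairs.
    scope = {
        "type": "http",
        "method": "GET",
        "path": "/",
        "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
    }
    return Request(scope)

# x-litellm-api-key takes priority over Authorization:
assert get_litellm_virtual_key(
    _make_request({"Authorization": "Bearer token-from-sdk", "x-litellm-api-key": "sk-1234"})
) == "Bearer sk-1234"
# Otherwise the Authorization header is returned unchanged:
assert get_litellm_virtual_key(_make_request({"Authorization": "Bearer sk-1234"})) == "Bearer sk-1234"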

View File

@@ -0,0 +1,965 @@
"""
What is this?
Provider-specific Pass-Through Endpoints
Use litellm with Anthropic SDK, Vertex AI SDK, Cohere SDK, etc.
"""
import json
import os
from typing import Literal, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.constants import BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.proxy._types import *
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.pass_through_endpoints.common_utils import get_litellm_virtual_key
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)
from litellm.secret_managers.main import get_secret_str
from .passthrough_endpoint_router import PassthroughEndpointRouter
vertex_llm_base = VertexBase()
router = APIRouter()
default_vertex_config = None
passthrough_endpoint_router = PassthroughEndpointRouter()
def create_request_copy(request: Request):
return {
"method": request.method,
"url": str(request.url),
"headers": dict(request.headers),
"cookies": request.cookies,
"query_params": dict(request.query_params),
}
async def llm_passthrough_factory_proxy_route(
custom_llm_provider: str,
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Factory function for creating pass-through endpoints for LLM providers.
"""
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager
provider_config = ProviderConfigManager.get_provider_model_info(
provider=LlmProviders(custom_llm_provider),
model=None,
)
if provider_config is None:
raise HTTPException(
status_code=404, detail=f"Provider {custom_llm_provider} not found"
)
base_target_url = provider_config.get_api_base()
if base_target_url is None:
raise HTTPException(
status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
)
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
provider_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider=custom_llm_provider,
region_name=None,
)
auth_headers = provider_config.validate_environment(
headers={},
model="",
messages=[],
optional_params={},
litellm_params={},
api_key=provider_api_key,
api_base=base_target_url,
)
## check for streaming
is_streaming_request = False
# the request is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers=auth_headers,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
@router.api_route(
"/gemini/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Google AI Studio Pass-through", "pass-through"],
)
async def gemini_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
"""
## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?key=LITELLM_API_KEY
google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
"x-goog-api-key"
)
user_api_key_dict = await user_api_key_auth(
request=request, api_key=f"Bearer {google_ai_studio_api_key}"
)
base_target_url = "https://generativelanguage.googleapis.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
gemini_api_key: Optional[str] = passthrough_endpoint_router.get_credentials(
custom_llm_provider="gemini",
region_name=None,
)
if gemini_api_key is None:
raise Exception(
"Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio."
)
# Merge query parameters, giving precedence to those in updated_url
merged_params = dict(request.query_params)
merged_params.update({"key": gemini_api_key})
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
query_params=merged_params, # type: ignore
stream=is_streaming_request, # type: ignore
)
return received_value
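A minimal client sketch, assuming the proxy is reachable at http://localhost:4000 and "sk-1234" is a LiteLLM virtual key (host, key, and model name are illustrative):
import httpx

response = httpx.post(
    "http://localhost:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent",
    params={"key": "sk-1234"},  # picked up by the query-param check in gemini_proxy_route
    json={"contents": [{"parts": [{"text": "Say hello"}]}]},
)
print(response.json())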
@router.api_route(
"/cohere/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Cohere Pass-through", "pass-through"],
)
async def cohere_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/cohere)
"""
base_target_url = "https://api.cohere.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
cohere_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="cohere",
region_name=None,
)
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)},
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
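A minimal client sketch, assuming a proxy at http://localhost:4000 and a LiteLLM virtual key "sk-1234"; the Cohere chat route and model name are illustrative:
import httpx

response = httpx.post(
    "http://localhost:4000/cohere/v1/chat",
    headers={"Authorization": "Bearer sk-1234"},
    json={"model": "command-r", "message": "Say hello"},
)
print(response.json())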
@router.api_route(
"/vllm/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["VLLM Pass-through", "pass-through"],
)
async def vllm_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/vllm)
"""
return await llm_passthrough_factory_proxy_route(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
custom_llm_provider="vllm",
)
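A minimal client sketch for the factory-backed route, assuming the proxy at http://localhost:4000 has an API base configured for the "vllm" provider and "sk-1234" is a LiteLLM virtual key (all values illustrative):
import httpx

response = httpx.post(
    "http://localhost:4000/vllm/v1/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={"model": "my-vllm-model", "messages": [{"role": "user", "content": "Say hello"}]},
)
print(response.json())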
@router.api_route(
"/mistral/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Mistral Pass-through", "pass-through"],
)
async def mistral_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/anthropic_completion)
"""
base_target_url = os.getenv("MISTRAL_API_BASE") or "https://api.mistral.ai"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
mistral_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="mistral",
region_name=None,
)
## check for streaming
is_streaming_request = False
# mistral is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"Authorization": "Bearer {}".format(mistral_api_key)},
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
@router.api_route(
"/anthropic/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Anthropic Pass-through", "pass-through"],
)
async def anthropic_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/anthropic_completion)
"""
base_target_url = "https://api.anthropic.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
anthropic_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="anthropic",
region_name=None,
)
## check for streaming
is_streaming_request = False
# anthropic is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"x-api-key": "{}".format(anthropic_api_key)},
_forward_headers=True,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
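A minimal sketch with the Anthropic SDK pointed at the proxy, assuming it runs at http://localhost:4000 and "sk-1234" is a LiteLLM virtual key (model name illustrative):
import anthropic

client = anthropic.Anthropic(
    base_url="http://localhost:4000/anthropic",  # requests land on the route above
    api_key="sk-1234",
)
message = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=256,
    messages=[{"role": "user", "content": "Say hello"}],
)
print(message.content)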
@router.api_route(
"/bedrock/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Bedrock Pass-through", "pass-through"],
)
async def bedrock_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
[Docs](https://docs.litellm.ai/docs/pass_through/bedrock)
"""
create_request_copy(request)
try:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
if _is_bedrock_agent_runtime_route(endpoint=endpoint): # handle bedrock agents
base_target_url = (
f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
)
else:
base_target_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
from litellm.llms.bedrock.chat import BedrockConverseLLM
credentials: Credentials = BedrockConverseLLM().get_credentials()
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
headers = {"Content-Type": "application/json"}
# Assuming the body contains JSON data, parse it
try:
data = await request.json()
except Exception as e:
raise HTTPException(status_code=400, detail={"error": str(e)})
_request = AWSRequest(
method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
)
sigv4.add_auth(_request)
prepped = _request.prepare()
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(prepped.url),
custom_headers=prepped.headers, # type: ignore
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
custom_body=data, # type: ignore
query_params={}, # type: ignore
)
return received_value
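A minimal client sketch, assuming the proxy runs at http://localhost:4000 with AWS credentials and AWS_REGION_NAME configured, and "sk-1234" is a LiteLLM virtual key (model ID illustrative):
import httpx

response = httpx.post(
    "http://localhost:4000/bedrock/model/anthropic.claude-3-sonnet-20240229-v1:0/converse",
    headers={"Authorization": "Bearer sk-1234"},
    json={"messages": [{"role": "user", "content": [{"text": "Say hello"}]}]},
)
print(response.json())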
def _is_bedrock_agent_runtime_route(endpoint: str) -> bool:
"""
Return True if the endpoint should be routed to the `bedrock-agent-runtime` endpoint.
"""
for _route in BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES:
if _route in endpoint:
return True
return False
@router.api_route(
"/assemblyai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["AssemblyAI Pass-through", "pass-through"],
)
@router.api_route(
"/eu.assemblyai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["AssemblyAI EU Pass-through", "pass-through"],
)
async def assemblyai_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
"""
[Docs](https://api.assemblyai.com)
"""
# Set base URL based on the route
assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
url=str(request.url)
)
base_target_url = (
AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
region=assembly_region
)
)
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
assemblyai_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="assemblyai",
region_name=assembly_region,
)
## check for streaming
is_streaming_request = False
# assemblyai is streaming when 'stream' = True is in the body
if request.method == "POST":
_request_body = await request.json()
if _request_body.get("stream"):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"Authorization": "{}".format(assemblyai_api_key)},
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
return received_value
@router.api_route(
"/azure/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Azure Pass-through", "pass-through"],
)
async def azure_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Call any azure endpoint using the proxy.
Just use `{PROXY_BASE_URL}/azure/{endpoint:path}`
"""
base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
if base_target_url is None:
raise Exception(
"Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
)
# Add or update query parameters
azure_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider=litellm.LlmProviders.AZURE.value,
region_name=None,
)
if azure_api_key is None:
raise Exception(
"Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
)
return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
base_target_url=base_target_url,
api_key=azure_api_key,
custom_llm_provider=litellm.LlmProviders.AZURE,
)
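A minimal client sketch, assuming AZURE_API_BASE/AZURE_API_KEY are set on the proxy, the proxy runs at http://localhost:4000, and "sk-1234" is a LiteLLM virtual key (deployment name and api-version are illustrative):
import httpx

response = httpx.post(
    "http://localhost:4000/azure/openai/deployments/my-deployment/chat/completions",
    params={"api-version": "2024-02-15-preview"},
    headers={"Authorization": "Bearer sk-1234"},
    json={"messages": [{"role": "user", "content": "Say hello"}]},
)
print(response.json())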
from abc import ABC, abstractmethod
class BaseVertexAIPassThroughHandler(ABC):
@staticmethod
@abstractmethod
def get_default_base_target_url(vertex_location: Optional[str]) -> str:
pass
@staticmethod
@abstractmethod
def update_base_target_url_with_credential_location(
base_target_url: str, vertex_location: Optional[str]
) -> str:
pass
class VertexAIDiscoveryPassThroughHandler(BaseVertexAIPassThroughHandler):
@staticmethod
def get_default_base_target_url(vertex_location: Optional[str]) -> str:
return "https://discoveryengine.googleapis.com/"
@staticmethod
def update_base_target_url_with_credential_location(
base_target_url: str, vertex_location: Optional[str]
) -> str:
return base_target_url
class VertexAIPassThroughHandler(BaseVertexAIPassThroughHandler):
@staticmethod
def get_default_base_target_url(vertex_location: Optional[str]) -> str:
return f"https://{vertex_location}-aiplatform.googleapis.com/"
@staticmethod
def update_base_target_url_with_credential_location(
base_target_url: str, vertex_location: Optional[str]
) -> str:
return f"https://{vertex_location}-aiplatform.googleapis.com/"
def get_vertex_pass_through_handler(
call_type: Literal["discovery", "aiplatform"]
) -> BaseVertexAIPassThroughHandler:
if call_type == "discovery":
return VertexAIDiscoveryPassThroughHandler()
elif call_type == "aiplatform":
return VertexAIPassThroughHandler()
else:
raise ValueError(f"Invalid call type: {call_type}")
async def _base_vertex_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
get_vertex_pass_through_handler: BaseVertexAIPassThroughHandler,
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
):
"""
Base function for Vertex AI passthrough routes.
Handles common logic for all Vertex AI services.
Default base_target_url is `https://{vertex_location}-aiplatform.googleapis.com/`
"""
from litellm.llms.vertex_ai.common_utils import (
construct_target_url,
get_vertex_location_from_url,
get_vertex_project_id_from_url,
)
encoded_endpoint = httpx.URL(endpoint).path
verbose_proxy_logger.debug("requested endpoint %s", endpoint)
headers: dict = {}
api_key_to_use = get_litellm_virtual_key(request=request)
user_api_key_dict = await user_api_key_auth(
request=request,
api_key=api_key_to_use,
)
vertex_project: Optional[str] = get_vertex_project_id_from_url(endpoint)
vertex_location: Optional[str] = get_vertex_location_from_url(endpoint)
vertex_credentials = passthrough_endpoint_router.get_vertex_credentials(
project_id=vertex_project,
location=vertex_location,
)
base_target_url = get_vertex_pass_through_handler.get_default_base_target_url(
vertex_location
)
headers_passed_through = False
# Use headers from the incoming request if no vertex credentials are found
if vertex_credentials is None or vertex_credentials.vertex_project is None:
headers = dict(request.headers) or {}
headers_passed_through = True
verbose_proxy_logger.debug(
"default_vertex_config not set, incoming request headers %s", headers
)
headers.pop("content-length", None)
headers.pop("host", None)
else:
vertex_project = vertex_credentials.vertex_project
vertex_location = vertex_credentials.vertex_location
vertex_credentials_str = vertex_credentials.vertex_credentials
_auth_header, vertex_project = await vertex_llm_base._ensure_access_token_async(
credentials=vertex_credentials_str,
project_id=vertex_project,
custom_llm_provider="vertex_ai_beta",
)
auth_header, _ = vertex_llm_base._get_token_and_url(
model="",
auth_header=_auth_header,
gemini_api_key=None,
vertex_credentials=vertex_credentials_str,
vertex_project=vertex_project,
vertex_location=vertex_location,
stream=False,
custom_llm_provider="vertex_ai_beta",
api_base="",
)
headers = {
"Authorization": f"Bearer {auth_header}",
}
base_target_url = get_vertex_pass_through_handler.update_base_target_url_with_credential_location(
base_target_url, vertex_location
)
if base_target_url is None:
base_target_url = f"https://{vertex_location}-aiplatform.googleapis.com/"
request_route = encoded_endpoint
verbose_proxy_logger.debug("request_route %s", request_route)
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
updated_url = construct_target_url(
base_url=base_target_url,
requested_route=encoded_endpoint,
vertex_location=vertex_location,
vertex_project=vertex_project,
)
verbose_proxy_logger.debug("updated url %s", updated_url)
## check for streaming
target = str(updated_url)
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
target += "?alt=sse"
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=target,
custom_headers=headers,
) # dynamically construct pass-through endpoint based on incoming path
try:
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
)
except ProxyException as e:
if headers_passed_through:
e.message = f"No credentials found on proxy for project_name={vertex_project} + location={vertex_location}, check `/model/info` for allowed project + region combinations with `use_in_pass_through: true`. Headers were passed through directly but request failed with error: {e.message}"
raise e
return received_value
@router.api_route(
"/vertex_ai/discovery/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_discovery_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
):
"""
Call any vertex discovery endpoint using the proxy.
Just use `{PROXY_BASE_URL}/vertex_ai/discovery/{endpoint:path}`
Target url: `https://discoveryengine.googleapis.com`
"""
discovery_handler = get_vertex_pass_through_handler(call_type="discovery")
return await _base_vertex_proxy_route(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
get_vertex_pass_through_handler=discovery_handler,
)
@router.api_route(
"/vertex-ai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Vertex AI Pass-through", "pass-through"],
include_in_schema=False,
)
@router.api_route(
"/vertex_ai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["Vertex AI Pass-through", "pass-through"],
)
async def vertex_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Call LiteLLM proxy via Vertex AI SDK.
[Docs](https://docs.litellm.ai/docs/pass_through/vertex_ai)
"""
ai_platform_handler = get_vertex_pass_through_handler(call_type="aiplatform")
return await _base_vertex_proxy_route(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
get_vertex_pass_through_handler=ai_platform_handler,
user_api_key_dict=user_api_key_dict,
)
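A minimal client sketch, assuming the proxy runs at http://localhost:4000 and "sk-1234" is a LiteLLM virtual key; project, location, and model are illustrative. The x-litellm-api-key header is the one get_litellm_virtual_key looks for:
import httpx

url = (
    "http://localhost:4000/vertex_ai/v1/projects/my-project/locations/us-central1/"
    "publishers/google/models/gemini-1.5-flash:generateContent"
)
response = httpx.post(
    url,
    headers={"x-litellm-api-key": "sk-1234"},
    json={"contents": [{"role": "user", "parts": [{"text": "Say hello"}]}]},
)
print(response.json())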
@router.api_route(
"/openai/{endpoint:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
tags=["OpenAI Pass-through", "pass-through"],
)
async def openai_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Simple pass-through for OpenAI. Use this if you want to directly send a request to OpenAI.
"""
base_target_url = "https://api.openai.com/"
# Add or update query parameters
openai_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider=litellm.LlmProviders.OPENAI.value,
region_name=None,
)
if openai_api_key is None:
raise Exception(
"Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
)
return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
endpoint=endpoint,
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
base_target_url=base_target_url,
api_key=openai_api_key,
custom_llm_provider=litellm.LlmProviders.OPENAI,
)
class BaseOpenAIPassThroughHandler:
@staticmethod
async def _base_openai_pass_through_handler(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth,
base_target_url: str,
api_key: str,
custom_llm_provider: litellm.LlmProviders,
):
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL by properly joining the base URL and endpoint path
base_url = httpx.URL(base_target_url)
updated_url = BaseOpenAIPassThroughHandler._join_url_paths(
base_url=base_url,
path=encoded_endpoint,
custom_llm_provider=custom_llm_provider,
)
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers=BaseOpenAIPassThroughHandler._assemble_headers(
api_key=api_key, request=request
),
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request, # type: ignore
query_params=dict(request.query_params), # type: ignore
)
return received_value
@staticmethod
def _append_openai_beta_header(headers: dict, request: Request) -> dict:
"""
Appends the OpenAI-Beta header to the headers if the request is an OpenAI Assistants API request
"""
if (
RouteChecks._is_assistants_api_request(request) is True
and "OpenAI-Beta" not in headers
):
headers["OpenAI-Beta"] = "assistants=v2"
return headers
@staticmethod
def _assemble_headers(api_key: str, request: Request) -> dict:
base_headers = {
"authorization": "Bearer {}".format(api_key),
"api-key": "{}".format(api_key),
}
return BaseOpenAIPassThroughHandler._append_openai_beta_header(
headers=base_headers,
request=request,
)
@staticmethod
def _join_url_paths(
base_url: httpx.URL, path: str, custom_llm_provider: litellm.LlmProviders
) -> str:
"""
Properly joins a base URL with a path, preserving any existing path in the base URL.
"""
# Join paths correctly by removing trailing/leading slashes as needed
if not base_url.path or base_url.path == "/":
# If base URL has no path, just use the new path
joined_path_str = str(base_url.copy_with(path=path))
else:
# Otherwise, combine the paths
base_path = base_url.path.rstrip("/")
clean_path = path.lstrip("/")
full_path = f"{base_path}/{clean_path}"
joined_path_str = str(base_url.copy_with(path=full_path))
# Apply OpenAI-specific path handling for both branches
if (
custom_llm_provider == litellm.LlmProviders.OPENAI
and "/v1/" not in joined_path_str
):
# Insert v1 after api.openai.com for OpenAI requests
joined_path_str = joined_path_str.replace(
"api.openai.com/", "api.openai.com/v1/"
)
return joined_path_str
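A quick illustration of _join_url_paths (URLs and deployment names are examples):
import httpx
import litellm

# Base URL without a path: the endpoint path is used directly, and "/v1/" is
# inserted for OpenAI targets that do not already include it.
assert BaseOpenAIPassThroughHandler._join_url_paths(
    base_url=httpx.URL("https://api.openai.com/"),
    path="/chat/completions",
    custom_llm_provider=litellm.LlmProviders.OPENAI,
) == "https://api.openai.com/v1/chat/completions"

# A base URL that already carries a path (e.g. an Azure resource) is preserved as a prefix.
assert BaseOpenAIPassThroughHandler._join_url_paths(
    base_url=httpx.URL("https://my-resource.openai.azure.com/openai"),
    path="/deployments/my-deployment/chat/completions",
    custom_llm_provider=litellm.LlmProviders.AZURE,
) == "https://my-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions"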

View File

@@ -0,0 +1,221 @@
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import ModelResponse, TextCompletionResponse
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
from ..types import EndpointType
else:
PassThroughEndpointLogging = Any
EndpointType = Any
class AnthropicPassthroughLoggingHandler:
@staticmethod
def anthropic_passthrough_handler(
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
) -> PassThroughEndpointLoggingTypedDict:
"""
Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled
"""
model = response_body.get("model", "")
litellm_model_response: ModelResponse = AnthropicConfig().transform_response(
raw_response=httpx_response,
model_response=litellm.ModelResponse(),
model=model,
messages=[],
logging_obj=logging_obj,
optional_params={},
api_key="",
request_data={},
encoding=litellm.encoding,
json_mode=False,
litellm_params={},
)
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=litellm_model_response,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
)
return {
"result": litellm_model_response,
"kwargs": kwargs,
}
@staticmethod
def _get_user_from_metadata(
passthrough_logging_payload: PassthroughStandardLoggingPayload,
) -> Optional[str]:
request_body = passthrough_logging_payload.get("request_body")
if request_body:
return get_end_user_id_from_request_body(request_body)
return None
@staticmethod
def _create_anthropic_response_logging_payload(
litellm_model_response: Union[ModelResponse, TextCompletionResponse],
model: str,
kwargs: dict,
start_time: datetime,
end_time: datetime,
logging_obj: LiteLLMLoggingObj,
):
"""
Create the standard logging object for Anthropic passthrough.
Handles streaming and non-streaming responses.
"""
try:
response_cost = litellm.completion_cost(
completion_response=litellm_model_response,
model=model,
)
kwargs["response_cost"] = response_cost
kwargs["model"] = model
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore
kwargs.get("passthrough_logging_payload")
)
if passthrough_logging_payload:
user = AnthropicPassthroughLoggingHandler._get_user_from_metadata(
passthrough_logging_payload=passthrough_logging_payload,
)
if user:
kwargs.setdefault("litellm_params", {})
kwargs["litellm_params"].update(
{"proxy_server_request": {"body": {"user": user}}}
)
# pretty print standard logging object
verbose_proxy_logger.debug(
"kwargs= %s",
json.dumps(kwargs, indent=4, default=str),
)
# set litellm_call_id to logging response object
litellm_model_response.id = logging_obj.litellm_call_id
litellm_model_response.model = model
logging_obj.model_call_details["model"] = model
logging_obj.model_call_details["custom_llm_provider"] = (
litellm.LlmProviders.ANTHROPIC.value
)
return kwargs
except Exception as e:
verbose_proxy_logger.exception(
"Error creating Anthropic response logging payload: %s", e
)
return kwargs
@staticmethod
def _handle_logging_anthropic_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks
- Builds complete response from chunks
- Creates standard logging object
- Logs in litellm callbacks
"""
model = request_body.get("model", "")
complete_streaming_response = (
AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
)
if complete_streaming_response is None:
verbose_proxy_logger.error(
"Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
)
return {
"result": None,
"kwargs": {},
}
kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
litellm_model_response=complete_streaming_response,
model=model,
kwargs={},
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
)
return {
"result": complete_streaming_response,
"kwargs": kwargs,
}
@staticmethod
def _build_complete_streaming_response(
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
"""
Builds complete response from raw Anthropic chunks
- Converts str chunks to generic chunks
- Converts generic chunks to litellm chunks (OpenAI format)
- Builds complete response from litellm chunks
"""
anthropic_model_response_iterator = AnthropicModelResponseIterator(
streaming_response=None,
sync_stream=False,
)
all_openai_chunks = []
for _chunk_str in all_chunks:
try:
transformed_openai_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
chunk=_chunk_str
)
if transformed_openai_chunk is not None:
all_openai_chunks.append(transformed_openai_chunk)
verbose_proxy_logger.debug(
"all openai chunks= %s",
json.dumps(all_openai_chunks, indent=4, default=str),
)
except (StopIteration, StopAsyncIteration):
break
complete_streaming_response = litellm.stream_chunk_builder(
chunks=all_openai_chunks
)
return complete_streaming_response

View File

@@ -0,0 +1,332 @@
import asyncio
import json
import time
from datetime import datetime
from typing import Literal, Optional, TypedDict
from urllib.parse import urlparse
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.types.passthrough_endpoints.assembly_ai import (
ASSEMBLY_AI_MAX_POLLING_ATTEMPTS,
ASSEMBLY_AI_POLLING_INTERVAL,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
class AssemblyAITranscriptResponse(TypedDict, total=False):
id: str
speech_model: str
acoustic_model: str
language_code: str
status: str
audio_duration: float
class AssemblyAIPassthroughLoggingHandler:
def __init__(self):
self.assembly_ai_base_url = "https://api.assemblyai.com"
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
"""
The base URLs for the AssemblyAI API (US and EU)
"""
self.polling_interval: float = ASSEMBLY_AI_POLLING_INTERVAL
"""
The polling interval for the AssemblyAI API.
litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript.
"""
self.max_polling_attempts = ASSEMBLY_AI_MAX_POLLING_ATTEMPTS
"""
The maximum number of polling attempts for the AssemblyAI API.
"""
def assemblyai_passthrough_logging_handler(
self,
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""
Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit.
"""
executor.submit(
self._handle_assemblyai_passthrough_logging,
httpx_response,
response_body,
logging_obj,
url_route,
result,
start_time,
end_time,
cache_hit,
**kwargs,
)
def _handle_assemblyai_passthrough_logging(
self,
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""
Handles logging for AssemblyAI successful passthrough requests
"""
from ..pass_through_endpoints import pass_through_endpoint_logging
model = response_body.get("speech_model", "")
verbose_proxy_logger.debug(
"response body %s", json.dumps(response_body, indent=4)
)
kwargs["model"] = model
kwargs["custom_llm_provider"] = "assemblyai"
response_cost: Optional[float] = None
transcript_id = response_body.get("id")
if transcript_id is None:
raise ValueError(
"Transcript ID is required to log the cost of the transcription"
)
transcript_response = self._poll_assembly_for_transcript_response(
transcript_id=transcript_id, url_route=url_route
)
verbose_proxy_logger.debug(
"finished polling assembly for transcript response- got transcript response %s",
json.dumps(transcript_response, indent=4),
)
if transcript_response:
cost = self.get_cost_for_assembly_transcript(
speech_model=model,
transcript_response=transcript_response,
)
response_cost = cost
# Make standard logging object for AssemblyAI
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=transcript_response,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
status="success",
)
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore
kwargs.get("passthrough_logging_payload")
)
verbose_proxy_logger.debug(
"standard_passthrough_logging_object %s",
json.dumps(passthrough_logging_payload, indent=4),
)
# pretty print standard logging object
verbose_proxy_logger.debug(
"standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
)
logging_obj.model_call_details["model"] = model
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
logging_obj.model_call_details["response_cost"] = response_cost
asyncio.run(
pass_through_endpoint_logging._handle_logging(
logging_obj=logging_obj,
standard_logging_response_object=self._get_response_to_log(
transcript_response
),
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
pass
def _get_response_to_log(
self, transcript_response: Optional[AssemblyAITranscriptResponse]
) -> dict:
if transcript_response is None:
return {}
return dict(transcript_response)
def _get_assembly_transcript(
self,
transcript_id: str,
request_region: Optional[Literal["eu"]] = None,
) -> Optional[dict]:
"""
Get the transcript details from AssemblyAI API
Args:
transcript_id (str): ID of the transcript to fetch
request_region (Optional[Literal["eu"]]): Which regional AssemblyAI endpoint to query
Returns:
Optional[dict]: Transcript details if successful, None otherwise
"""
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
passthrough_endpoint_router,
)
_base_url = (
self.assembly_ai_eu_base_url
if request_region == "eu"
else self.assembly_ai_base_url
)
_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="assemblyai",
region_name=request_region,
)
if _api_key is None:
raise ValueError("AssemblyAI API key not found")
try:
url = f"{_base_url}/v2/transcript/{transcript_id}"
headers = {
"Authorization": f"Bearer {_api_key}",
"Content-Type": "application/json",
}
response = httpx.get(url, headers=headers)
response.raise_for_status()
return response.json()
except Exception as e:
verbose_proxy_logger.exception(
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
)
return None
def _poll_assembly_for_transcript_response(
self,
transcript_id: str,
url_route: Optional[str] = None,
) -> Optional[AssemblyAITranscriptResponse]:
"""
Poll the status of the transcript until it is completed or errored, or until polling times out (~30 minutes)
"""
for _ in range(
self.max_polling_attempts
): # 180 attempts * 10s = 30 minutes max
transcript = self._get_assembly_transcript(
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
url=url_route
),
transcript_id=transcript_id,
)
if transcript is None:
return None
if (
transcript.get("status") == "completed"
or transcript.get("status") == "error"
):
return AssemblyAITranscriptResponse(**transcript)
time.sleep(self.polling_interval)
return None
@staticmethod
def get_cost_for_assembly_transcript(
transcript_response: AssemblyAITranscriptResponse,
speech_model: str,
) -> Optional[float]:
"""
Get the cost for the assembly transcript
"""
_audio_duration = transcript_response.get("audio_duration")
if _audio_duration is None:
return None
_cost_per_second = (
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
speech_model=speech_model
)
)
if _cost_per_second is None:
return None
return _audio_duration * _cost_per_second
@staticmethod
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
"""
Get the cost per second for the assembly model.
Falls back to assemblyai/nano if the specific speech model info cannot be found.
"""
try:
# First try with the provided speech model
try:
model_info = litellm.get_model_info(
model=speech_model,
custom_llm_provider="assemblyai",
)
if model_info and model_info.get("input_cost_per_second") is not None:
return model_info.get("input_cost_per_second")
except Exception:
pass # Continue to fallback if model not found
# Fallback to assemblyai/nano if speech model info not found
try:
model_info = litellm.get_model_info(
model="assemblyai/nano",
custom_llm_provider="assemblyai",
)
if model_info and model_info.get("input_cost_per_second") is not None:
return model_info.get("input_cost_per_second")
except Exception:
pass
return None
except Exception as e:
verbose_proxy_logger.exception(
f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
)
return None
@staticmethod
def _should_log_request(request_method: str) -> bool:
"""
Only POST transcription jobs are logged. litellm polls AssemblyAI until the transcription completes so the full response and cost can be logged.
"""
return request_method == "POST"
@staticmethod
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
"""
Get the region from the URL
"""
if url is None:
return None
if urlparse(url).hostname == "eu.assemblyai.com":
return "eu"
return None
@staticmethod
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
"""
Get the base URL for the AssemblyAI API
if region == "eu", return "https://api.eu.assemblyai.com"
else return "https://api.assemblyai.com"
"""
if region == "eu":
return "https://api.eu.assemblyai.com"
return "https://api.assemblyai.com"

View File

@@ -0,0 +1,221 @@
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
from ..types import EndpointType
else:
PassThroughEndpointLogging = Any
EndpointType = Any
from abc import ABC, abstractmethod
class BasePassthroughLoggingHandler(ABC):
@property
@abstractmethod
def llm_provider_name(self) -> LlmProviders:
pass
@abstractmethod
def get_provider_config(self, model: str) -> BaseConfig:
pass
def passthrough_chat_handler(
self,
httpx_response: httpx.Response,
response_body: dict,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
request_body: dict,
**kwargs,
) -> PassThroughEndpointLoggingTypedDict:
"""
Transforms LLM response to OpenAI response, generates a standard logging object so downstream logging can be handled
"""
model = request_body.get("model", response_body.get("model", ""))
provider_config = self.get_provider_config(model=model)
litellm_model_response: ModelResponse = provider_config.transform_response(
raw_response=httpx_response,
model_response=litellm.ModelResponse(),
model=model,
messages=[],
logging_obj=logging_obj,
optional_params={},
api_key="",
request_data={},
encoding=litellm.encoding,
json_mode=False,
litellm_params={},
)
kwargs = self._create_response_logging_payload(
litellm_model_response=litellm_model_response,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
)
return {
"result": litellm_model_response,
"kwargs": kwargs,
}
def _get_user_from_metadata(
self,
passthrough_logging_payload: PassthroughStandardLoggingPayload,
) -> Optional[str]:
request_body = passthrough_logging_payload.get("request_body")
if request_body:
return get_end_user_id_from_request_body(request_body)
return None
def _create_response_logging_payload(
self,
litellm_model_response: Union[ModelResponse, TextCompletionResponse],
model: str,
kwargs: dict,
start_time: datetime,
end_time: datetime,
logging_obj: LiteLLMLoggingObj,
) -> dict:
"""
Create the standard logging object for a generic LLM passthrough.
Handles streaming and non-streaming responses.
"""
try:
response_cost = litellm.completion_cost(
completion_response=litellm_model_response,
model=model,
)
kwargs["response_cost"] = response_cost
kwargs["model"] = model
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore
kwargs.get("passthrough_logging_payload")
)
if passthrough_logging_payload:
user = self._get_user_from_metadata(
passthrough_logging_payload=passthrough_logging_payload,
)
if user:
kwargs.setdefault("litellm_params", {})
kwargs["litellm_params"].update(
{"proxy_server_request": {"body": {"user": user}}}
)
# Make standard logging object for the passthrough provider
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=litellm_model_response,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
status="success",
)
# pretty print standard logging object
verbose_proxy_logger.debug(
"standard_logging_object= %s",
json.dumps(standard_logging_object, indent=4),
)
kwargs["standard_logging_object"] = standard_logging_object
# set litellm_call_id to logging response object
litellm_model_response.id = logging_obj.litellm_call_id
litellm_model_response.model = model
logging_obj.model_call_details["model"] = model
return kwargs
except Exception as e:
verbose_proxy_logger.exception(
"Error creating LLM passthrough response logging payload: %s", e
)
return kwargs
@abstractmethod
def _build_complete_streaming_response(
self,
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
"""
Builds complete response from raw chunks
- Converts str chunks to generic chunks
- Converts generic chunks to litellm chunks (OpenAI format)
- Builds complete response from litellm chunks
"""
pass
def _handle_logging_llm_collected_chunks(
self,
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from the LLM passthrough endpoint and logs them in litellm callbacks
- Builds complete response from chunks
- Creates standard logging object
- Logs in litellm callbacks
"""
model = request_body.get("model", "")
complete_streaming_response = self._build_complete_streaming_response(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
if complete_streaming_response is None:
verbose_proxy_logger.error(
"Unable to build complete streaming response for LLM passthrough endpoint, not logging..."
)
return {
"result": None,
"kwargs": {},
}
kwargs = self._create_response_logging_payload(
litellm_model_response=complete_streaming_response,
model=model,
kwargs={},
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
)
return {
"result": complete_streaming_response,
"kwargs": kwargs,
}

View File

@@ -0,0 +1,56 @@
from typing import List, Optional, Union
from litellm import stream_chunk_builder
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.cohere.chat.v2_transformation import CohereV2ChatConfig
from litellm.llms.cohere.common_utils import (
ModelResponseIterator as CohereModelResponseIterator,
)
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
from .base_passthrough_logging_handler import BasePassthroughLoggingHandler
class CoherePassthroughLoggingHandler(BasePassthroughLoggingHandler):
@property
def llm_provider_name(self) -> LlmProviders:
return LlmProviders.COHERE
def get_provider_config(self, model: str) -> BaseConfig:
return CohereV2ChatConfig()
def _build_complete_streaming_response(
self,
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
cohere_model_response_iterator = CohereModelResponseIterator(
streaming_response=None,
sync_stream=False,
)
litellm_custom_stream_wrapper = CustomStreamWrapper(
completion_stream=cohere_model_response_iterator,
model=model,
logging_obj=litellm_logging_obj,
custom_llm_provider="cohere",
)
all_openai_chunks = []
for _chunk_str in all_chunks:
try:
generic_chunk = (
cohere_model_response_iterator.convert_str_chunk_to_generic_chunk(
chunk=_chunk_str
)
)
litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
chunk=generic_chunk
)
if litellm_chunk is not None:
all_openai_chunks.append(litellm_chunk)
except (StopIteration, StopAsyncIteration):
break
complete_streaming_response = stream_chunk_builder(chunks=all_openai_chunks)
return complete_streaming_response

View File

@@ -0,0 +1,261 @@
import json
import re
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexModelResponseIterator,
)
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import (
EmbeddingResponse,
ImageResponse,
ModelResponse,
TextCompletionResponse,
)
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
from ..types import EndpointType
else:
PassThroughEndpointLogging = Any
EndpointType = Any
class VertexPassthroughLoggingHandler:
@staticmethod
def vertex_passthrough_handler(
httpx_response: httpx.Response,
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
) -> PassThroughEndpointLoggingTypedDict:
if "generateContent" in url_route:
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
instance_of_vertex_llm = litellm.VertexGeminiConfig()
litellm_model_response: ModelResponse = (
instance_of_vertex_llm.transform_response(
model=model,
messages=[
{"role": "user", "content": "no-message-pass-through-endpoint"}
],
raw_response=httpx_response,
model_response=litellm.ModelResponse(),
logging_obj=logging_obj,
optional_params={},
litellm_params={},
api_key="",
request_data={},
encoding=litellm.encoding,
)
)
kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
litellm_model_response=litellm_model_response,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
url_route
),
)
return {
"result": litellm_model_response,
"kwargs": kwargs,
}
elif "predict" in url_route:
from litellm.llms.vertex_ai.image_generation.image_generation_handler import (
VertexImageGeneration,
)
from litellm.types.utils import PassthroughCallTypes
vertex_image_generation_class = VertexImageGeneration()
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
_json_response = httpx_response.json()
litellm_prediction_response: Union[
ModelResponse, EmbeddingResponse, ImageResponse
] = ModelResponse()
if vertex_image_generation_class.is_image_generation_response(
_json_response
):
litellm_prediction_response = (
vertex_image_generation_class.process_image_generation_response(
_json_response,
model_response=litellm.ImageResponse(),
model=model,
)
)
logging_obj.call_type = (
PassthroughCallTypes.passthrough_image_generation.value
)
else:
litellm_prediction_response = litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
response=_json_response,
model=model,
model_response=litellm.EmbeddingResponse(),
)
if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
litellm_prediction_response.model = model
logging_obj.model = model
logging_obj.model_call_details["model"] = logging_obj.model
return {
"result": litellm_prediction_response,
"kwargs": kwargs,
}
else:
return {
"result": None,
"kwargs": kwargs,
}
@staticmethod
def _handle_logging_vertex_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
- Builds complete response from chunks
- Creates standard logging object
- Logs in litellm callbacks
"""
kwargs: Dict[str, Any] = {}
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
complete_streaming_response = (
VertexPassthroughLoggingHandler._build_complete_streaming_response(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
)
if complete_streaming_response is None:
verbose_proxy_logger.error(
"Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
)
return {
"result": None,
"kwargs": kwargs,
}
kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
litellm_model_response=complete_streaming_response,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
url_route
),
)
return {
"result": complete_streaming_response,
"kwargs": kwargs,
}
@staticmethod
def _build_complete_streaming_response(
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
vertex_iterator = VertexModelResponseIterator(
streaming_response=None,
sync_stream=False,
)
litellm_custom_stream_wrapper = litellm.CustomStreamWrapper(
completion_stream=vertex_iterator,
model=model,
logging_obj=litellm_logging_obj,
custom_llm_provider="vertex_ai",
)
all_openai_chunks = []
for chunk in all_chunks:
generic_chunk = vertex_iterator._common_chunk_parsing_logic(chunk)
litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
chunk=generic_chunk
)
if litellm_chunk is not None:
all_openai_chunks.append(litellm_chunk)
complete_streaming_response = litellm.stream_chunk_builder(
chunks=all_openai_chunks
)
return complete_streaming_response
@staticmethod
def extract_model_from_url(url: str) -> str:
pattern = r"/models/([^:]+)"
match = re.search(pattern, url)
if match:
return match.group(1)
return "unknown"
@staticmethod
def _get_custom_llm_provider_from_url(url: str) -> str:
parsed_url = urlparse(url)
if parsed_url.hostname and parsed_url.hostname.endswith(
"generativelanguage.googleapis.com"
):
return litellm.LlmProviders.GEMINI.value
return litellm.LlmProviders.VERTEX_AI.value
@staticmethod
def _create_vertex_response_logging_payload_for_generate_content(
litellm_model_response: Union[ModelResponse, TextCompletionResponse],
model: str,
kwargs: dict,
start_time: datetime,
end_time: datetime,
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: str,
):
"""
Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming)
"""
response_cost = litellm.completion_cost(
completion_response=litellm_model_response,
model=model,
)
kwargs["response_cost"] = response_cost
kwargs["model"] = model
        # pretty-print the logging kwargs for debugging
verbose_proxy_logger.debug("kwargs= %s", json.dumps(kwargs, indent=4))
        # use the litellm_call_id from the logging object as the response id
litellm_model_response.id = logging_obj.litellm_call_id
logging_obj.model = litellm_model_response.model or model
logging_obj.model_call_details["model"] = logging_obj.model
logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
return kwargs

View File

@@ -0,0 +1,193 @@
from typing import Dict, Optional
from litellm._logging import verbose_router_logger
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials
class PassthroughEndpointRouter:
"""
    Use this class to set/get credentials for pass-through endpoints
"""
def __init__(self):
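        # Maps credential name (e.g. "ANTHROPIC_API_KEY") to the stored API key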
self.credentials: Dict[str, str] = {}
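        # Maps "{project_id}-{location}" to the vertex credentials registered for that deployment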
self.deployment_key_to_vertex_credentials: Dict[
str, VertexPassThroughCredentials
] = {}
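        # Fallback vertex credentials used when no project/location-specific entry is found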
self.default_vertex_config: Optional[VertexPassThroughCredentials] = None
def set_pass_through_credentials(
self,
custom_llm_provider: str,
api_base: Optional[str],
api_key: Optional[str],
):
"""
        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint in the UI.
Args:
custom_llm_provider: The provider of the pass-through endpoint
api_base: The base URL of the pass-through endpoint
api_key: The API key for the pass-through endpoint
"""
credential_name = self._get_credential_name_for_provider(
custom_llm_provider=custom_llm_provider,
region_name=self._get_region_name_from_api_base(
api_base=api_base, custom_llm_provider=custom_llm_provider
),
)
if api_key is None:
raise ValueError("api_key is required for setting pass-through credentials")
self.credentials[credential_name] = api_key
def get_credentials(
self,
custom_llm_provider: str,
region_name: Optional[str],
) -> Optional[str]:
credential_name = self._get_credential_name_for_provider(
custom_llm_provider=custom_llm_provider,
region_name=region_name,
)
verbose_router_logger.debug(
f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
)
if credential_name in self.credentials:
verbose_router_logger.debug(f"Found credentials for {credential_name}")
return self.credentials[credential_name]
else:
verbose_router_logger.debug(
f"No credentials found for {credential_name}, looking for env variable"
)
_env_variable_name = (
self._get_default_env_variable_name_passthrough_endpoint(
custom_llm_provider=custom_llm_provider,
)
)
return get_secret_str(_env_variable_name)
def _get_vertex_env_vars(self) -> VertexPassThroughCredentials:
"""
Helper to get vertex pass through config from environment variables
The following environment variables are used:
- DEFAULT_VERTEXAI_PROJECT (project id)
- DEFAULT_VERTEXAI_LOCATION (location)
- DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
"""
return VertexPassThroughCredentials(
vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
)
def set_default_vertex_config(self, config: Optional[dict] = None):
"""Sets vertex configuration from provided config and/or environment variables
Args:
config (Optional[dict]): Configuration dictionary
Example: {
"vertex_project": "my-project-123",
"vertex_location": "us-central1",
"vertex_credentials": "os.environ/GOOGLE_CREDS"
}
"""
        # Fall back to environment variables when no config is provided
if config is None:
self.default_vertex_config = self._get_vertex_env_vars()
return
if isinstance(config, dict):
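            # Resolve "os.environ/<VAR>" references through the secret manager before building credentials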
for key, value in config.items():
if isinstance(value, str) and value.startswith("os.environ/"):
config[key] = get_secret_str(value)
self.default_vertex_config = VertexPassThroughCredentials(**config)
def add_vertex_credentials(
self,
project_id: str,
location: str,
vertex_credentials: VERTEX_CREDENTIALS_TYPES,
):
"""
Add the vertex credentials for the given project-id, location
"""
deployment_key = self._get_deployment_key(
project_id=project_id,
location=location,
)
if deployment_key is None:
verbose_router_logger.debug(
"No deployment key found for project-id, location"
)
return
vertex_pass_through_credentials = VertexPassThroughCredentials(
vertex_project=project_id,
vertex_location=location,
vertex_credentials=vertex_credentials,
)
self.deployment_key_to_vertex_credentials[
deployment_key
] = vertex_pass_through_credentials
def _get_deployment_key(
self, project_id: Optional[str], location: Optional[str]
) -> Optional[str]:
"""
Get the deployment key for the given project-id, location
"""
if project_id is None or location is None:
return None
return f"{project_id}-{location}"
def get_vertex_credentials(
self, project_id: Optional[str], location: Optional[str]
) -> Optional[VertexPassThroughCredentials]:
"""
Get the vertex credentials for the given project-id, location
"""
deployment_key = self._get_deployment_key(
project_id=project_id,
location=location,
)
if deployment_key is None:
return self.default_vertex_config
if deployment_key in self.deployment_key_to_vertex_credentials:
return self.deployment_key_to_vertex_credentials[deployment_key]
else:
return self.default_vertex_config
def _get_credential_name_for_provider(
self,
custom_llm_provider: str,
region_name: Optional[str],
) -> str:
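        # e.g. ("anthropic", None) -> "ANTHROPIC_API_KEY"; ("assemblyai", "eu") -> "ASSEMBLYAI_EU_API_KEY"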
if region_name is None:
return f"{custom_llm_provider.upper()}_API_KEY"
return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
def _get_region_name_from_api_base(
self,
custom_llm_provider: str,
api_base: Optional[str],
) -> Optional[str]:
"""
Get the region name from the API base.
        Each provider encodes the region differently in its API base, so provider-specific parsing lives here.
"""
if custom_llm_provider == "assemblyai":
if api_base and "eu" in api_base:
return "eu"
return None
@staticmethod
def _get_default_env_variable_name_passthrough_endpoint(
custom_llm_provider: str,
) -> str:
return f"{custom_llm_provider.upper()}_API_KEY"
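
# Illustrative usage sketch (hypothetical provider and key, shown for reference only):
#
#   router = PassthroughEndpointRouter()
#   router.set_pass_through_credentials(
#       custom_llm_provider="cohere", api_base=None, api_key="my-cohere-key"
#   )
#   router.get_credentials(custom_llm_provider="cohere", region_name=None)  # -> "my-cohere-key"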

View File

@@ -0,0 +1,159 @@
import asyncio
import threading
from datetime import datetime
from typing import List, Optional
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from litellm.types.utils import StandardPassThroughResponseObject
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
from .success_handler import PassThroughEndpointLogging
class PassThroughStreamingHandler:
@staticmethod
async def chunk_processor(
response: httpx.Response,
request_body: Optional[dict],
litellm_logging_obj: LiteLLMLoggingObj,
endpoint_type: EndpointType,
start_time: datetime,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
):
"""
        - Yields chunks from the response as they arrive
        - Collects the raw chunks for post-processing (logging)
"""
try:
raw_bytes: List[bytes] = []
async for chunk in response.aiter_bytes():
raw_bytes.append(chunk)
yield chunk
# After all chunks are processed, handle post-processing
end_time = datetime.now()
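            # Fire-and-forget: logging runs on the event loop so the generator can finish without waiting on it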
asyncio.create_task(
PassThroughStreamingHandler._route_streaming_logging_to_handler(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body or {},
endpoint_type=endpoint_type,
start_time=start_time,
raw_bytes=raw_bytes,
end_time=end_time,
)
)
except Exception as e:
verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
raise
@staticmethod
async def _route_streaming_logging_to_handler(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
raw_bytes: List[bytes],
end_time: datetime,
):
"""
Route the logging for the collected chunks to the appropriate handler
Supported endpoint types:
- Anthropic
- Vertex AI
"""
all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
raw_bytes
)
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
kwargs: dict = {}
if endpoint_type == EndpointType.ANTHROPIC:
anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif endpoint_type == EndpointType.VERTEX_AI:
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
litellm_logging_obj=litellm_logging_obj,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route=url_route,
request_body=request_body,
endpoint_type=endpoint_type,
start_time=start_time,
all_chunks=all_chunks,
end_time=end_time,
)
)
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
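        # Fall back to logging the raw chunks when no provider-specific handler produced a typed response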
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
)
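        # Run the synchronous success handler in a thread and the async success handler on the event loop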
threading.Thread(
target=litellm_logging_obj.success_handler,
args=(
standard_logging_response_object,
start_time,
end_time,
False,
),
).start()
await litellm_logging_obj.async_success_handler(
result=standard_logging_response_object,
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
@staticmethod
def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
"""
Converts a list of raw bytes into a list of string lines, similar to aiter_lines()
Args:
            raw_bytes: List of byte chunks from aiter_bytes()
Returns:
List of string lines, with each line being a complete data: {} chunk
"""
# Combine all bytes and decode to string
combined_str = b"".join(raw_bytes).decode("utf-8")
# Split by newlines and filter out empty lines
lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
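        # e.g. [b'data: {"a": 1}\n', b'\ndata: {"b": 2}\n'] -> ['data: {"a": 1}', 'data: {"b": 2}']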
return lines

View File

@@ -0,0 +1,221 @@
import json
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import StandardPassThroughResponseObject
from litellm.utils import executor as thread_pool_executor
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.cohere_passthrough_logging_handler import (
CoherePassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler()
class PassThroughEndpointLogging:
def __init__(self):
self.TRACKED_VERTEX_ROUTES = [
"generateContent",
"streamGenerateContent",
"predict",
]
# Anthropic
self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
# Cohere
self.TRACKED_COHERE_ROUTES = ["/v2/chat"]
self.assemblyai_passthrough_logging_handler = (
AssemblyAIPassthroughLoggingHandler()
)
async def _handle_logging(
self,
logging_obj: LiteLLMLoggingObj,
standard_logging_response_object: Union[
StandardPassThroughResponseObject,
PassThroughEndpointLoggingResultValues,
dict,
],
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
**kwargs,
):
"""Helper function to handle both sync and async logging operations"""
# Submit to thread pool for sync logging
thread_pool_executor.submit(
logging_obj.success_handler,
standard_logging_response_object,
start_time,
end_time,
cache_hit,
**kwargs,
)
# Handle async logging
await logging_obj.async_success_handler(
result=(
json.dumps(result)
if isinstance(result, dict)
else standard_logging_response_object
),
start_time=start_time,
end_time=end_time,
cache_hit=False,
**kwargs,
)
async def pass_through_async_success_handler(
self,
httpx_response: httpx.Response,
response_body: Optional[dict],
logging_obj: LiteLLMLoggingObj,
url_route: str,
result: str,
start_time: datetime,
end_time: datetime,
cache_hit: bool,
request_body: dict,
passthrough_logging_payload: PassthroughStandardLoggingPayload,
**kwargs,
):
standard_logging_response_object: Optional[
PassThroughEndpointLoggingResultValues
] = None
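        # Expose the passthrough payload to logging callbacks via model_call_details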
logging_obj.model_call_details["passthrough_logging_payload"] = (
passthrough_logging_payload
)
if self.is_vertex_route(url_route):
vertex_passthrough_logging_handler_result = (
VertexPassthroughLoggingHandler.vertex_passthrough_handler(
httpx_response=httpx_response,
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
standard_logging_response_object = (
vertex_passthrough_logging_handler_result["result"]
)
kwargs = vertex_passthrough_logging_handler_result["kwargs"]
elif self.is_anthropic_route(url_route):
anthropic_passthrough_logging_handler_result = (
AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
)
standard_logging_response_object = (
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif self.is_cohere_route(url_route):
cohere_passthrough_logging_handler_result = (
cohere_passthrough_logging_handler.passthrough_chat_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
request_body=request_body,
**kwargs,
)
)
standard_logging_response_object = (
cohere_passthrough_logging_handler_result["result"]
)
kwargs = cohere_passthrough_logging_handler_result["kwargs"]
elif self.is_assemblyai_route(url_route):
if (
AssemblyAIPassthroughLoggingHandler._should_log_request(
httpx_response.request.method
)
is not True
):
return
self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
return
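        # Default to logging the raw response text when no provider-specific handler matched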
if standard_logging_response_object is None:
standard_logging_response_object = StandardPassThroughResponseObject(
response=httpx_response.text
)
await self._handle_logging(
logging_obj=logging_obj,
standard_logging_response_object=standard_logging_response_object,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
**kwargs,
)
def is_vertex_route(self, url_route: str):
for route in self.TRACKED_VERTEX_ROUTES:
if route in url_route:
return True
return False
def is_anthropic_route(self, url_route: str):
for route in self.TRACKED_ANTHROPIC_ROUTES:
if route in url_route:
return True
return False
def is_cohere_route(self, url_route: str):
for route in self.TRACKED_COHERE_ROUTES:
if route in url_route:
                return True
        return False
def is_assemblyai_route(self, url_route: str):
parsed_url = urlparse(url_route)
if parsed_url.hostname == "api.assemblyai.com":
return True
elif "/transcript" in parsed_url.path:
return True
return False