structure saas with tools
Binary file not shown.
@@ -0,0 +1,214 @@
"""
BETA

This is the PubSub logger for GCS PubSub. It sends LiteLLM SpendLogs payloads to a GCS PubSub topic.

Users can use this instead of sending their SpendLogs to their Postgres database.
"""
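
# Usage sketch (hedged; not part of this module): one way to wire this logger into
# LiteLLM is to construct it and register it as a callback. The project/topic/credential
# values below are hypothetical; the GCS_PUBSUB_* environment variables read in __init__
# can be used instead of explicit arguments.
#
#   import litellm
#
#   gcs_pub_sub_logger = GcsPubSubLogger(
#       project_id="my-gcp-project",
#       topic_id="litellm-spend-logs",
#       credentials_path="/path/to/service_account.json",
#   )
#   litellm.callbacks = [gcs_pub_sub_logger]
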
import asyncio
import json
import os
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from litellm.types.utils import StandardLoggingPayload

if TYPE_CHECKING:
    from litellm.proxy._types import SpendLogsPayload
else:
    SpendLogsPayload = Any

import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)


class GcsPubSubLogger(CustomBatchLogger):
    def __init__(
        self,
        project_id: Optional[str] = None,
        topic_id: Optional[str] = None,
        credentials_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Google Cloud Pub/Sub publisher

        Args:
            project_id (str): Google Cloud project ID
            topic_id (str): Pub/Sub topic ID
            credentials_path (str, optional): Path to Google Cloud credentials JSON file
        """
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()

        self.async_httpx_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )

        self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID")
        self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID")
        self.path_service_account_json = credentials_path or os.getenv(
            "GCS_PATH_SERVICE_ACCOUNT"
        )

        if not self.project_id or not self.topic_id:
            raise ValueError("Both project_id and topic_id must be provided")

        self.flush_lock = asyncio.Lock()
        super().__init__(**kwargs, flush_lock=self.flush_lock)
        asyncio.create_task(self.periodic_flush())
        self.log_queue: List[Union[SpendLogsPayload, StandardLoggingPayload]] = []
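
    # Construction note (a sketch under assumptions, not authoritative): __init__ above
    # schedules periodic_flush() with asyncio.create_task, so the logger should be
    # constructed while an event loop is running, and _premium_user_check() gates it to
    # premium usage. Relying purely on the environment variables looks like:
    #
    #   async def _setup():  # hypothetical helper name
    #       # assumes GCS_PUBSUB_PROJECT_ID / GCS_PUBSUB_TOPIC_ID (and optionally
    #       # GCS_PATH_SERVICE_ACCOUNT) are exported before this runs
    #       return GcsPubSubLogger()
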
    async def construct_request_headers(self) -> Dict[str, str]:
        """Construct authorization headers using Vertex AI auth"""
        from litellm import vertex_chat_completion

        # Reuse LiteLLM's Vertex AI credential flow to obtain a Google OAuth2 access
        # token from the service account JSON / default credentials.
        (
            _auth_header,
            vertex_project,
        ) = await vertex_chat_completion._ensure_access_token_async(
            credentials=self.path_service_account_json,
            project_id=self.project_id,
            custom_llm_provider="vertex_ai",
        )

        auth_header, _ = vertex_chat_completion._get_token_and_url(
            model="pub-sub",
            auth_header=_auth_header,
            vertex_credentials=self.path_service_account_json,
            vertex_project=vertex_project,
            vertex_location=None,
            gemini_api_key=None,
            stream=None,
            custom_llm_provider="vertex_ai",
            api_base=None,
        )

        headers = {
            "Authorization": f"Bearer {auth_header}",
            "Content-Type": "application/json",
        }
        return headers

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async log success events to the GCS PubSub topic

        - Creates a SpendLogsPayload
        - Adds it to the batch queue
        - Flushes based on CustomBatchLogger settings

        Raises:
            Nothing is raised to the caller; errors are logged non-blockingly via
            verbose_logger.exception
        """
        from litellm.proxy.spend_tracking.spend_tracking_utils import (
            get_logging_payload,
        )
        from litellm.proxy.utils import _premium_user_check

        _premium_user_check()

        try:
            verbose_logger.debug(
                "PubSub: Logging - Enters logging function for model %s", kwargs
            )
            standard_logging_payload = kwargs.get("standard_logging_object", None)

            # Backwards compatibility with the old logging payload
            if litellm.gcs_pub_sub_use_v1 is True:
                spend_logs_payload = get_logging_payload(
                    kwargs=kwargs,
                    response_obj=response_obj,
                    start_time=start_time,
                    end_time=end_time,
                )
                self.log_queue.append(spend_logs_payload)
            else:
                # New logging payload, StandardLoggingPayload
                self.log_queue.append(standard_logging_payload)

            if len(self.log_queue) >= self.batch_size:
                await self.async_send_batch()

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}"
            )
            pass
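
    # Manual invocation sketch (hedged): in normal operation LiteLLM's logging pipeline
    # calls this hook; the kwargs below are illustrative only and assume the default
    # (non-gcs_pub_sub_use_v1) path, where just "standard_logging_object" is read.
    #
    #   from datetime import datetime
    #
    #   await gcs_pub_sub_logger.async_log_success_event(
    #       kwargs={"standard_logging_object": {"model": "gpt-4o", "response_cost": 0.002}},
    #       response_obj=None,
    #       start_time=datetime.now(),
    #       end_time=datetime.now(),
    #   )
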
    async def async_send_batch(self):
        """
        Sends the batch of messages to Pub/Sub
        """
        try:
            if not self.log_queue:
                return

            verbose_logger.debug(
                f"PubSub - about to flush {len(self.log_queue)} events"
            )

            for message in self.log_queue:
                await self.publish_message(message)

        except Exception as e:
            verbose_logger.exception(
                f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}"
            )
        finally:
            self.log_queue.clear()
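
    # Batch tuning sketch (an assumption: the CustomBatchLogger base class accepts
    # batch_size and flush_interval kwargs, which determine when async_send_batch runs):
    #
    #   gcs_pub_sub_logger = GcsPubSubLogger(
    #       project_id="my-gcp-project",    # hypothetical
    #       topic_id="litellm-spend-logs",  # hypothetical
    #       batch_size=100,                 # flush after 100 queued payloads
    #       flush_interval=10,              # or every 10 seconds via periodic_flush()
    #   )
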
    async def publish_message(
        self, message: Union[SpendLogsPayload, StandardLoggingPayload]
    ) -> Optional[Dict[str, Any]]:
        """
        Publish message to Google Cloud Pub/Sub using REST API

        Args:
            message: Message to publish (dict or string)

        Returns:
            dict: Published message response, or None if publishing failed
        """
        try:
            headers = await self.construct_request_headers()

            # Prepare message data
            if isinstance(message, str):
                message_data = message
            else:
                message_data = json.dumps(message, default=str)

            # Base64 encode the message, as required by the Pub/Sub REST API
            import base64

            encoded_message = base64.b64encode(message_data.encode("utf-8")).decode(
                "utf-8"
            )

            # Construct request body
            request_body = {"messages": [{"data": encoded_message}]}

            url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish"

            response = await self.async_httpx_client.post(
                url=url, headers=headers, json=request_body
            )

            if response.status_code not in [200, 202]:
                verbose_logger.error("Pub/Sub publish error: %s", str(response.text))
                raise Exception(f"Failed to publish message: {response.text}")

            verbose_logger.debug("Pub/Sub response: %s", response.text)
            return response.json()

        except Exception as e:
            verbose_logger.error("Pub/Sub publish error: %s", str(e))
            return None
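
# Direct publish sketch (hedged): publish_message can also be awaited outside the batch
# path. A successful Pub/Sub REST publish responds with a JSON body of message IDs,
# e.g. {"messageIds": ["1234567890"]}; on failure the method above returns None.
#
#   resp = await gcs_pub_sub_logger.publish_message({"event": "test"})
#   if resp is not None:
#       print(resp.get("messageIds"))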