structure saas with tools
This commit is contained in:
@@ -0,0 +1,281 @@
|
||||
# used for monitoring litellm services health on `/metrics` endpoint on LiteLLM Proxy
|
||||
#### What this does ####
|
||||
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
|
||||
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.types.integrations.prometheus import LATENCY_BUCKETS
|
||||
from litellm.types.services import (
|
||||
DEFAULT_SERVICE_CONFIGS,
|
||||
ServiceLoggerPayload,
|
||||
ServiceMetrics,
|
||||
ServiceTypes,
|
||||
)
|
||||
|
||||
FAILED_REQUESTS_LABELS = ["error_class", "function_name"]
|
||||
|
||||
|
||||
class PrometheusServicesLogger:
|
||||
# Class variables or attributes
|
||||
litellm_service_latency = None # Class-level attribute to store the Histogram
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mock_testing: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
try:
|
||||
from prometheus_client import REGISTRY, Counter, Gauge, Histogram
|
||||
from prometheus_client.gc_collector import Collector
|
||||
except ImportError:
|
||||
raise Exception(
|
||||
"Missing prometheus_client. Run `pip install prometheus-client`"
|
||||
)
|
||||
|
||||
self.Histogram = Histogram
|
||||
self.Counter = Counter
|
||||
self.Gauge = Gauge
|
||||
self.REGISTRY = REGISTRY
|
||||
|
||||
verbose_logger.debug("in init prometheus services metrics")
|
||||
|
||||
self.payload_to_prometheus_map: Dict[
|
||||
str, List[Union[Histogram, Counter, Gauge, Collector]]
|
||||
] = {}
|
||||
|
||||
for service in ServiceTypes:
|
||||
service_metrics: List[Union[Histogram, Counter, Gauge, Collector]] = []
|
||||
|
||||
metrics_to_initialize = self._get_service_metrics_initialize(service)
|
||||
|
||||
# Initialize only the configured metrics for each service
|
||||
if ServiceMetrics.HISTOGRAM in metrics_to_initialize:
|
||||
histogram = self.create_histogram(
|
||||
service.value, type_of_request="latency"
|
||||
)
|
||||
if histogram:
|
||||
service_metrics.append(histogram)
|
||||
|
||||
if ServiceMetrics.COUNTER in metrics_to_initialize:
|
||||
counter_failed_request = self.create_counter(
|
||||
service.value,
|
||||
type_of_request="failed_requests",
|
||||
additional_labels=FAILED_REQUESTS_LABELS,
|
||||
)
|
||||
if counter_failed_request:
|
||||
service_metrics.append(counter_failed_request)
|
||||
counter_total_requests = self.create_counter(
|
||||
service.value, type_of_request="total_requests"
|
||||
)
|
||||
if counter_total_requests:
|
||||
service_metrics.append(counter_total_requests)
|
||||
|
||||
if ServiceMetrics.GAUGE in metrics_to_initialize:
|
||||
gauge = self.create_gauge(service.value, type_of_request="size")
|
||||
if gauge:
|
||||
service_metrics.append(gauge)
|
||||
|
||||
if service_metrics:
|
||||
self.payload_to_prometheus_map[service.value] = service_metrics
|
||||
|
||||
self.prometheus_to_amount_map: dict = {}
|
||||
### MOCK TESTING ###
|
||||
self.mock_testing = mock_testing
|
||||
self.mock_testing_success_calls = 0
|
||||
self.mock_testing_failure_calls = 0
|
||||
|
||||
except Exception as e:
|
||||
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
||||
raise e
|
||||
|
||||
def _get_service_metrics_initialize(
|
||||
self, service: ServiceTypes
|
||||
) -> List[ServiceMetrics]:
|
||||
DEFAULT_METRICS = [ServiceMetrics.COUNTER, ServiceMetrics.HISTOGRAM]
|
||||
if service not in DEFAULT_SERVICE_CONFIGS:
|
||||
return DEFAULT_METRICS
|
||||
|
||||
metrics = DEFAULT_SERVICE_CONFIGS.get(service, {}).get("metrics", [])
|
||||
if not metrics:
|
||||
verbose_logger.debug(f"No metrics found for service {service}")
|
||||
return DEFAULT_METRICS
|
||||
return metrics
|
||||
|
||||
def is_metric_registered(self, metric_name) -> bool:
|
||||
for metric in self.REGISTRY.collect():
|
||||
if metric_name == metric.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_metric(self, metric_name):
|
||||
"""
|
||||
Helper function to get a metric from the registry by name.
|
||||
"""
|
||||
return self.REGISTRY._names_to_collectors.get(metric_name)
|
||||
|
||||
def create_histogram(self, service: str, type_of_request: str):
|
||||
metric_name = "litellm_{}_{}".format(service, type_of_request)
|
||||
is_registered = self.is_metric_registered(metric_name)
|
||||
if is_registered:
|
||||
return self._get_metric(metric_name)
|
||||
return self.Histogram(
|
||||
metric_name,
|
||||
"Latency for {} service".format(service),
|
||||
labelnames=[service],
|
||||
buckets=LATENCY_BUCKETS,
|
||||
)
|
||||
|
||||
def create_gauge(self, service: str, type_of_request: str):
|
||||
metric_name = "litellm_{}_{}".format(service, type_of_request)
|
||||
is_registered = self.is_metric_registered(metric_name)
|
||||
if is_registered:
|
||||
return self._get_metric(metric_name)
|
||||
return self.Gauge(
|
||||
metric_name, "Gauge for {} service".format(service), labelnames=[service]
|
||||
)
|
||||
|
||||
def create_counter(
|
||||
self,
|
||||
service: str,
|
||||
type_of_request: str,
|
||||
additional_labels: Optional[List[str]] = None,
|
||||
):
|
||||
metric_name = "litellm_{}_{}".format(service, type_of_request)
|
||||
is_registered = self.is_metric_registered(metric_name)
|
||||
if is_registered:
|
||||
return self._get_metric(metric_name)
|
||||
return self.Counter(
|
||||
metric_name,
|
||||
"Total {} for {} service".format(type_of_request, service),
|
||||
labelnames=[service] + (additional_labels or []),
|
||||
)
|
||||
|
||||
def observe_histogram(
|
||||
self,
|
||||
histogram,
|
||||
labels: str,
|
||||
amount: float,
|
||||
):
|
||||
assert isinstance(histogram, self.Histogram)
|
||||
|
||||
histogram.labels(labels).observe(amount)
|
||||
|
||||
def update_gauge(
|
||||
self,
|
||||
gauge,
|
||||
labels: str,
|
||||
amount: float,
|
||||
):
|
||||
assert isinstance(gauge, self.Gauge)
|
||||
gauge.labels(labels).set(amount)
|
||||
|
||||
def increment_counter(
|
||||
self,
|
||||
counter,
|
||||
labels: str,
|
||||
amount: float,
|
||||
additional_labels: Optional[List[str]] = [],
|
||||
):
|
||||
assert isinstance(counter, self.Counter)
|
||||
|
||||
if additional_labels:
|
||||
counter.labels(labels, *additional_labels).inc(amount)
|
||||
else:
|
||||
counter.labels(labels).inc(amount)
|
||||
|
||||
def service_success_hook(self, payload: ServiceLoggerPayload):
|
||||
if self.mock_testing:
|
||||
self.mock_testing_success_calls += 1
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
if isinstance(obj, self.Histogram):
|
||||
self.observe_histogram(
|
||||
histogram=obj,
|
||||
labels=payload.service.value,
|
||||
amount=payload.duration,
|
||||
)
|
||||
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
|
||||
def service_failure_hook(self, payload: ServiceLoggerPayload):
|
||||
if self.mock_testing:
|
||||
self.mock_testing_failure_calls += 1
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
if isinstance(obj, self.Counter):
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG ERROR COUNT / TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
|
||||
async def async_service_success_hook(self, payload: ServiceLoggerPayload):
|
||||
"""
|
||||
Log successful call to prometheus
|
||||
"""
|
||||
if self.mock_testing:
|
||||
self.mock_testing_success_calls += 1
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
if isinstance(obj, self.Histogram):
|
||||
self.observe_histogram(
|
||||
histogram=obj,
|
||||
labels=payload.service.value,
|
||||
amount=payload.duration,
|
||||
)
|
||||
elif isinstance(obj, self.Counter) and "total_requests" in obj._name:
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
elif isinstance(obj, self.Gauge):
|
||||
if payload.event_metadata:
|
||||
self.update_gauge(
|
||||
gauge=obj,
|
||||
labels=payload.event_metadata.get("gauge_labels") or "",
|
||||
amount=payload.event_metadata.get("gauge_value") or 0,
|
||||
)
|
||||
|
||||
async def async_service_failure_hook(
|
||||
self,
|
||||
payload: ServiceLoggerPayload,
|
||||
error: Union[str, Exception],
|
||||
):
|
||||
if self.mock_testing:
|
||||
self.mock_testing_failure_calls += 1
|
||||
error_class = error.__class__.__name__
|
||||
function_name = payload.call_type
|
||||
|
||||
if payload.service.value in self.payload_to_prometheus_map:
|
||||
prom_objects = self.payload_to_prometheus_map[payload.service.value]
|
||||
for obj in prom_objects:
|
||||
# increment both failed and total requests
|
||||
if isinstance(obj, self.Counter):
|
||||
if "failed_requests" in obj._name:
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
# log additional_labels=["error_class", "function_name"], used for debugging what's going wrong with the DB
|
||||
additional_labels=[error_class, function_name],
|
||||
amount=1, # LOG ERROR COUNT TO PROMETHEUS
|
||||
)
|
||||
else:
|
||||
self.increment_counter(
|
||||
counter=obj,
|
||||
labels=payload.service.value,
|
||||
amount=1, # LOG TOTAL REQUESTS TO PROMETHEUS
|
||||
)
|
||||
Reference in New Issue
Block a user